Commit 74bc48e: Initial commit
Parent(s): none
This view is limited to 50 files because it contains too many changes.
- .gitattributes +35 -0
- Dyna-1/LICENSE.txt +73 -0
- Dyna-1/README.md +115 -0
- Dyna-1/configs/af2.yml +31 -0
- Dyna-1/configs/baseline.yml +28 -0
- Dyna-1/configs/esm2.yml +28 -0
- Dyna-1/configs/esm3.yml +28 -0
- Dyna-1/data/dataloader.py +341 -0
- Dyna-1/data/vocab.py +104 -0
- Dyna-1/esm/__init__.py +2 -0
- Dyna-1/esm/data/ParentChildTreeFile.txt +0 -0
- Dyna-1/esm/data/entry_list_safety_29026.list +0 -0
- Dyna-1/esm/data/interpro_29026_to_keywords_58641.csv +0 -0
- Dyna-1/esm/data/keyword_idf_safety_filtered_58641.npy +0 -0
- Dyna-1/esm/data/keyword_vocabulary_safety_filtered_58641.txt +0 -0
- Dyna-1/esm/layers/attention.py +76 -0
- Dyna-1/esm/layers/blocks.py +153 -0
- Dyna-1/esm/layers/codebook.py +88 -0
- Dyna-1/esm/layers/ffn.py +29 -0
- Dyna-1/esm/layers/geom_attention.py +149 -0
- Dyna-1/esm/layers/regression_head.py +22 -0
- Dyna-1/esm/layers/rotary.py +221 -0
- Dyna-1/esm/layers/structure_proj.py +66 -0
- Dyna-1/esm/layers/transformer_stack.py +93 -0
- Dyna-1/esm/models/esm3.py +606 -0
- Dyna-1/esm/models/esmc.py +164 -0
- Dyna-1/esm/models/function_decoder.py +306 -0
- Dyna-1/esm/models/vqvae.py +440 -0
- Dyna-1/esm/pretrained.py +132 -0
- Dyna-1/esm/sdk/__init__.py +22 -0
- Dyna-1/esm/sdk/api.py +445 -0
- Dyna-1/esm/sdk/forge.py +580 -0
- Dyna-1/esm/sdk/sagemaker.py +110 -0
- Dyna-1/esm/tokenization/__init__.py +69 -0
- Dyna-1/esm/tokenization/function_tokenizer.py +429 -0
- Dyna-1/esm/tokenization/residue_tokenizer.py +236 -0
- Dyna-1/esm/tokenization/sasa_tokenizer.py +153 -0
- Dyna-1/esm/tokenization/sequence_tokenizer.py +89 -0
- Dyna-1/esm/tokenization/ss_tokenizer.py +125 -0
- Dyna-1/esm/tokenization/structure_tokenizer.py +83 -0
- Dyna-1/esm/tokenization/tokenizer_base.py +44 -0
- Dyna-1/esm/utils/constants/api.py +5 -0
- Dyna-1/esm/utils/constants/esm3.py +130 -0
- Dyna-1/esm/utils/constants/models.py +25 -0
- Dyna-1/esm/utils/constants/physics.py +5 -0
- Dyna-1/esm/utils/decoding.py +244 -0
- Dyna-1/esm/utils/encoding.py +246 -0
- Dyna-1/esm/utils/function/encode_decode.py +187 -0
- Dyna-1/esm/utils/function/interpro.py +178 -0
- Dyna-1/esm/utils/function/lsh.py +102 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
```
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
```
Dyna-1/LICENSE.txt
ADDED
@@ -0,0 +1,73 @@
Non-Commercial License Agreement.

Copyright (c) 2025 Brandeis University, Gina El Nesr, Hannah Wayment-Steele.

IMPORTANT: PLEASE READ THIS NON-COMMERCIAL LICENSE AGREEMENT ("AGREEMENT") CAREFULLY BEFORE USING THE SOFTWARE. BY DOWNLOADING, INSTALLING, OR USING THE SOFTWARE, YOU (THE "LICENSEE") AGREE TO BE BOUND BY THE TERMS AND CONDITIONS OF THIS AGREEMENT. IF YOU DO NOT AGREE TO THE TERMS AND CONDITIONS OF THIS AGREEMENT, DO NOT DOWNLOAD, INSTALL, OR USE THE SOFTWARE.

WHEREAS, Brandeis University (the "Licensor") is the licensor of the Dyna-1 model (the "Model");

DEFINITIONS, in addition to other terms defined elsewhere have the following meanings:

THE MODEL (the "Model" or the "Software") means the Dyna-1 released model as a combination of the Dyna-1 Model Code and the Dyna-1 Model Weights and all software, algorithms, machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code, and other elements of the foregoing distributed and made available by the Licensor on the GitHub Page (https://github.com/WaymentSteeleLab/Dyna-1) or other code expressly listed within the GitHub Page including but not limited to Google Colab Notebooks, each as may be updated and amended from time to time, whether in source or object form. Any instances of any source code or weights that depend explicitly on ESM-3 Models are licensed under the EvolutionaryScale Cambrian Non-Commercial License Agreement and subject to the terms that are more restrictive.

COMMERCIAL USE means any use or activity intended for or directed towards commercial advantage or monetary compensation, including without limitation, the development of any product, method, or service intended to be sold or made available for a fee.

COMMERCIAL ENTITY means any individual, entity, or organization involved in any capacity in Commercial Use. For the purposes of this Agreement, references to a Commercial Entity expressly exclude any universities, non-profit organizations, not-for-profit entities, research institutions, educational bodies, and government bodies.

NON-COMMERCIAL USE means use not leading to or directed towards commercial advantage or monetary compensations, or the facilitation of development of any product, method, or service to be sold or made available for a fee. Non-commercial use excludes any activities conducted by commercial entities, even if those activities themselves do not directly generate revenue. Examples of non-commercial use include academic research, personal projects, and use by non-profit organizations.

DERIVATIVE WORKS means any work, in source or object form, that is based on or derived from the Model.

OUTPUT DERIVATIVES means any outputs resulting from use of the Model. This includes, but is not limited to, any predictions that result from the Model.

NOW, THEREFORE, the Licensor and the Licensee hereto agree as follows:

GRANT OF LICENSE.

Non-Commercial Use. The Model is provided under the MIT License for non-commercial use. The MIT License is included below.

Commercial Use. If the Licensee is or represents a Commercial Entity and wishes to use the Model for Commercial Use, the Licensee must first contact the Licensor for a commercial license. Examples of Commercial Use include but are not limited to:

Selling the Model, either as is or as part of a larger software package;

Using the Model as an interface or component of a service where customers pay to use it;

Providing the Model, either as is or as part of a larger software package;

Incorporating the Model, in part or in whole, into any commercial operation or workflow that may lead to commercial use of a product, method, or service;

Use of any derivative works or output derivatives towards any commercial operation or workflow that may lead to commercial use.

RESTRICTIONS. The Licensee shall not:

Use, modify, or distribute the Model for commercial purposes without receiving a commercial license.

Sublicense, assign, or transfer the Model without Licensor's prior written consent.

Use the Model in any manner that infringes upon any third-party rights.

Use the Model in any way that is not expressly permitted by this Agreement.

OWNERSHIP. Licensor retains all right, title, and interest in and to the Model, including all intellectual property rights therein. No rights are granted to Licensee other than those expressly set forth in this Agreement.

TERM AND TERMINATION. This Agreement shall commence upon Licensee's initial use of the Model and continue unless terminated by either party upon providing written notice to the other party. Upon termination, Licensee shall immediately cease using the Model and destroy all copies thereof.

LIABILITY. Licensor shall not be liable for any direct, indirect, incidental, special, consequential, or exemplary damages, including but not limited to damages for loss of profits, goodwill, use, data, or other intangible losses, resulting from the use of the Model.

GOVERNING LAW. This Agreement shall be governed by and construed in accordance with the laws of the United States of America.

MIT LICENSE

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

By using the Software, Licensee acknowledges that they have read this Agreement, understood it, and agreed to be bound by its terms and conditions.

CONTACT INFORMATION.
For commercial license inquiries, please contact: Brandeis University Office of Technology Licensing at [email protected]
For technical inquiries, please contact: [email protected], [email protected]

[END OF AGREEMENT]
Dyna-1/README.md
ADDED
@@ -0,0 +1,115 @@
# Dyna-1
[](https://python.org/downloads)
[](https://colab.research.google.com/github/WaymentSteeleLab/Dyna-1/blob/main/colab/Dyna_1.ipynb)

Dyna-1 is a model introduced in our paper, ["Learning millisecond protein dynamics from what is missing in NMR spectra"](https://www.biorxiv.org/content/10.1101/2025.03.19.642801v1).

Given a sequence and/or structure, Dyna-1 will predict the probability that each residue experiences micro-millisecond motions.

Dyna-1 was built using the `esm3-sm-open-v1` weights from ESM-3. Inference with this model is subject to the EvolutionaryScale Cambrian Non-Commercial License Agreement of the ESM-3 Model and requires read permission for the weights found [here](https://huggingface.co/EvolutionaryScale/esm3-sm-open-v1). We also make available an alternate version of Dyna-1 that uses ESM-2 embeddings; use of this model is subject to a Non-Commercial License Agreement.

To make Dyna-1 readily accessible for research purposes, we also provide a [Google Colab](https://colab.research.google.com/github/WaymentSteeleLab/Dyna-1/blob/main/colab/Dyna_1.ipynb).

We provide the curated datasets used to evaluate Dyna-1: 133 curated R1/R2/NOE datasets ("RelaxDB") and 10 relaxation-dispersion Carr-Purcell-Meiboom-Gill datasets ("RelaxDB-CPMG").

In this repository we provide:
* [Installation](#installation)
* [Inference code for Dyna-1](#inference)
* [PyMol visualization](#visualization)
* [RelaxDB and RelaxDB-CPMG datasets](#datasets)
* [Training code for Dyna-1](#training)
* [Citation](#citation)
* [Acknowledgements](#acknowledgement)

If you have any questions not covered here, please create an issue and we will get back to you as soon as possible.

# Installation
To run the scripts in this repository, we recommend using a conda environment. First, clone this repository. Then navigate to the root directory and run:
```
conda create -n dyna1 python=3.11
conda activate dyna1
```
This package requires PyTorch, ideally with GPU support. For more information, follow the instructions at https://pytorch.org/get-started/locally/ for your system and CUDA version. We used PyTorch 2.5.0 with CUDA 12.4. To install all of the requirements:
```
pip install -r requirements.txt
```
Then, download the model weights and place them in the `model/weights` folder. The weights can be found on 🤗HuggingFace at <a href='https://huggingface.co/gelnesr/Dyna-1'>gelnesr/Dyna-1</a>. More information on how to download them can be found <a href='https://github.com/gelnesr/Dyna-1-public/blob/main/model/weights/README.md'>here</a>.

# Inference

The best-performing Dyna-1 is based on ESM-3. To run this version, you will have to request access to the ESM-3 `esm3-sm-open-v1` weights on HuggingFace [here](https://huggingface.co/EvolutionaryScale/esm3-sm-open-v1). Follow the steps to agree to their license terms and receive your access token for the model weights.

> [!NOTE]
> If this is your first time requesting access to the ESM-3 weights, you may need to set up your access token. For more information on how to set up an SSH token, please consult <a href='https://huggingface.co/docs/hub/en/security-git-ssh'>this</a> tutorial. Alternatively, you can use the HuggingFace login prompt, which will ask for the access token each time you re-instantiate Dyna-1. This can be configured by adding the following code to the inference script: `from huggingface_hub import login; login()`

To run inference using our best-performing model, run:

```
python dyna1.py --pdb <PDB CODE or PATH> --chain <CHAIN> --name <NAME> --use_pdb_seq --write_to_pdb
```

We provide three options for running inference: sequence and structure input (best performance!), sequence only, or structure only. Additionally, we make it possible to test different sequences on the same backbone. Examples of how to run each of these modes can be found in the `scripts/` folder.

To write the probabilities to the input structure, make sure to pass the `--write_to_pdb` flag.

Alternatively, we also provide a version of Dyna-1 based on ESM-2. To run inference using this version of the model, run:

```
python dyna1-esm2.py --sequence <SEQUENCE> --name <NAME>
```

# Visualization

We visualize probabilities of exchange on protein structures using [PyMol](https://www.pymol.org). To re-create the putty visualization on your protein, import the pdb file into PyMol and copy-paste the following commands into the PyMol command line:

```
cartoon putty; set cartoon_putty_transform, 6; set cartoon_putty_radius, 0.25; set cartoon_putty_range, 0.1; set cartoon_putty_scale_max, 10
```

Annotated:
```
cartoon putty
set cartoon_putty_transform, 6   # scaled linear transformation
set cartoon_putty_radius, 0.25   # min radius for p=0
set cartoon_putty_range, 0.1     # min_radius / max_radius, sets max_radius=2.5
set cartoon_putty_scale_max, 10  # max_radius / min_radius
```

# Datasets

*RelaxDB* contains 133 R1/R2/NOE datasets curated from the [BMRB](https://bmrb.io/) and from the literature.

*RelaxDB-CPMG* contains motion labels derived from 10 CPMG relaxation-dispersion datasets curated from the literature.

These datasets are made available on 🤗HuggingFace at <a href='https://huggingface.co/datasets/gelnesr/RelaxDB'>datasets/gelnesr/RelaxDB</a>.

In this repo, you can find:
- data formatted for input into Dyna-1 in `data/RelaxDB_pkls_22jan2025.zip`
- datasets in json format in `data/RelaxDB_datasets/`
- demo notebooks for visualizing and using the datasets to evaluate model outputs in `analysis/`

# Training

Training code will be made available upon journal publication.

# Citation

If you are using our code, datasets, or model, please use the following citation:
```bibtex
@article {Dyna-1,
    author = {Wayment-Steele, Hannah K. and El Nesr, Gina and Hettiarachchi, Ramith and Kariyawasam, Hasindu and Ovchinnikov, Sergey and Kern, Dorothee},
    title = {Learning millisecond protein dynamics from what is missing in NMR spectra},
    year = {2025},
    doi = {10.1101/2025.03.19.642801},
    journal = {bioRxiv}
}
```
# Acknowledgements

We would like to acknowledge the EvolutionaryScale team for their contributions to the field with ESM-3. The code in `esm` is imported from `evolutionaryscale/esm` with all modifications identified, and includes the associated LICENSE terms for the ESM-3 model.

We would also like to acknowledge the FAIR team for their contributions to the field with ESM-2. The ESM-2 model is accessed through the HuggingFace API.

We thank Katie Henzler-Wildman, Magnus Wolf-Watz, Elan Eisenmesser, J. Patrick Loria, Marcellus Ubbelink, George Lisi, Sam Butcher, and Nicolas Doucet for sharing data. We thank Martin Stone for sharing the Indiana Dynamics Database data his group curated in 2000.
Dyna-1/configs/af2.yml
ADDED
@@ -0,0 +1,31 @@
```yaml
model_name: 'af2'
train:
  epochs: 500
  batchsize: 4
  accum_steps: 4
  lr: 0.00001
  dropout: 0.1
  num_workers: 1
model:
  hidden_size: 128
  res_count: 32
  length: 400
  nheads: 8
  nlayers: 12
dir:
  save_dir: '/path/to/save/dir'
data:
  pair_rep: None
  sample_clusters: True
  relaxdb:
    split: 'relaxdb'
    crop_len: 400
    type: 'rex'
  cpmg:
    split: 'relaxdb-cpmg'
    crop_len: 400
    type: 'cpmg'
wandb:
  project: "New-Project"
  team: 'my-project-team'
  dir: '/wandb/log/dir'
```
Dyna-1/configs/baseline.yml
ADDED
@@ -0,0 +1,28 @@
```yaml
model_name: 'baseline'
train:
  epochs: 500
  batchsize: 16
  accum_steps: 1
  lr: 0.000001
  dropout: 0.1
  num_workers: 2
model:
  nheads: 6
  nlayers: 10
dir:
  save_dir: '/path/to/save/dir'
data:
  pair_rep: None
  sample_clusters: True
  relaxdb:
    split: 'relaxdb'
    crop_len: 367
    type: 'rex'
  cpmg:
    split: 'relaxdb-cpmg'
    crop_len: 367
    type: 'cpmg'
wandb:
  project: "New-Project"
  team: 'my-project-team'
  dir: '/wandb/log/dir'
```
Dyna-1/configs/esm2.yml
ADDED
@@ -0,0 +1,28 @@
```yaml
model_name: 'esm2_t30_150M_UR50D'
train:
  epochs: 500
  batchsize: 16
  accum_steps: 1
  lr: 0.000001
  dropout: 0.1
  num_workers: 2
model:
  nheads: 8
  nlayers: 12
dir:
  save_dir: '/path/to/save/dir'
data:
  pair_rep: None
  sample_clusters: True
  relaxdb:
    split: 'relaxdb'
    crop_len: 367
    type: 'rex'
  cpmg:
    split: 'relaxdb-cpmg'
    crop_len: 367
    type: 'cpmg'
wandb:
  project: "New-Project"
  team: 'my-project-team'
  dir: '/wandb/log/dir'
```
Dyna-1/configs/esm3.yml
ADDED
@@ -0,0 +1,28 @@
```yaml
model_name: 'esm3'
train:
  epochs: 500
  batchsize: 16
  accum_steps: 1
  lr: 0.000001
  dropout: 0.1
  num_workers: 2
model:
  nheads: 6
  nlayers: 12
dir:
  save_dir: '/path/to/save/dir'
data:
  pair_rep: None
  sample_clusters: True
  relaxdb:
    split: 'relaxdb'
    crop_len: 367
    type: 'rex'
  cpmg:
    split: 'relaxdb-cpmg'
    crop_len: 367
    type: 'cpmg'
wandb:
  project: "New-Project"
  team: 'my-project-team'
  dir: '/wandb/log/dir'
```
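The four configs above share the same layout. As a quick orientation only (the training entry point that consumes them is not part of this commit), they can be read with PyYAML; the nesting assumed below simply mirrors the files as shown above, not any documented loader.

```python
# Hypothetical loader sketch for the configs above; the repo's own training
# script is not included in this commit, so this is an assumption, not its API.
import yaml

with open("configs/esm3.yml") as f:
    cfg = yaml.safe_load(f)

print(cfg["model_name"])                           # 'esm3'
print(cfg["train"]["epochs"], cfg["train"]["lr"])  # 500 1e-06
```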
Dyna-1/data/dataloader.py
ADDED
@@ -0,0 +1,341 @@
```python
import numpy as np
import warnings
import random
import torch
import os
import pickle
import data.vocab as vocab
import pandas as pd
from typing import Tuple, List, Any
from esm.sdk.api import ESMProtein
from transformers import AutoTokenizer
import utils

class DynaData(torch.utils.data.Dataset):
    """
    For each protein, we use a pkl file that contains:
    seq: The domain sequence, stored as an L-length string
    assns: string containing labels of dynamics type
    """

    def __init__(
        self,
        split,
        type = 'missing',
        sample_clusters = False,
        cluster_file = None,
        crop_len = 300,
        missing_only = False,
        rex_only = False,
        unsuppressed = False,
        method = None,
        pair_rep = None,
        return_dssp = False
    ):
        super().__init__()

        self.return_dssp = return_dssp
        self.crop_len = crop_len
        self.sample_clusters = sample_clusters
        self.label_tokenizer = vocab.label_tokenizer(type = type,
                                                     missing_only = missing_only,
                                                     rex_only = rex_only,
                                                     unsuppressed = unsuppressed)

        # tokenization is the same for all ESM models, use the lightest one
        self.tokenizer = AutoTokenizer.from_pretrained(f"facebook/esm2_t6_8M_UR50D")
        self.proline = self.tokenizer.get_vocab()['P']

        self.method = method[0]
        self.model = method[1]

        if isinstance(split, str):
            split = [split]

        self.all_names, self.names = [], []

        # read in all pdb names
        for fil in split:
            filename = f'data/split_files/{fil}.txt'
            with open(filename,'r') as f:
                self.all_names.extend(f.read().splitlines())

        # set up cluster sampling
        if self.sample_clusters:
            self.cluster_info = pd.read_csv(f'data/{cluster_file}.tsv', sep='\t')
            self.cluster_info['cluster'] = self.cluster_info.apply(lambda row: row['cluster'], axis=1)
            for nm in self.all_names:
                subset = self.cluster_info.loc[self.cluster_info.entry_ID==nm]
                if len(subset) == 0:
                    print('NO!', nm)
                cluster_ind = subset['cluster'].iloc[0]
                if cluster_ind not in self.names:
                    self.names.append(cluster_ind)
        else:
            self.names = self.all_names

        self.pair_rep_dir = pair_rep

    def __len__(self):
        return len(self.names)

    def __baseline_get_item__(self, name, obj, crop_start):
        if crop_start > -1:
            sequence = obj['sequence'][crop_start:crop_start+self.crop_len]
        else:
            sequence = obj['sequence'][:self.crop_len]

        sequence_tokens = self.tokenizer.encode(sequence,
                                                add_special_tokens=False,
                                                padding='max_length',
                                                max_length=self.crop_len,
                                                return_tensors='np').squeeze()

        # set mask to 1 for length of seq, padded tokens then are 0
        eval_mask = np.zeros_like(sequence_tokens)
        eval_mask[:len(sequence)] = 1

        sequence_id = sequence_tokens != 0

        # mask prolines in eval
        eval_mask[sequence_tokens==self.proline] = 0

        return sequence_tokens, sequence_id, eval_mask

    def __af2_get_item__(self, name, obj, crop_start):
        """
        Prepares input for the AF2-pair model
        """

        pair_rep = np.load(f"{self.pair_rep_dir}/{name}.npy")
        labels, seq = obj['label'], obj['sequence']

        if crop_start > -1:
            pair_rep = pair_rep[crop_start:crop_start+self.crop_len,
                                crop_start:crop_start+self.crop_len, :]
            labels = labels[crop_start:crop_start+self.crop_len]
            seq = seq[crop_start:crop_start+self.crop_len]

        eval_mask = np.zeros((pair_rep.shape[0],))

        prolines = [i for i, aa in enumerate(seq) if aa == 'P']
        eval_mask[:len(labels)] = 1

        sequence_id = eval_mask != 0
        eval_mask[prolines] = 0
        x = pair_rep.shape[0]

        eval_mask = np.pad(eval_mask, (0, self.crop_len - len(eval_mask)), mode='constant')
        sequence_id = np.pad(sequence_id, (0, self.crop_len - len(sequence_id)), mode='constant')
        if x < self.crop_len:
            pair_rep = np.pad(pair_rep, ((0, self.crop_len - x), (0, self.crop_len - x), (0, 0)), mode='constant', constant_values=0)

        return pair_rep, sequence_id, eval_mask

    def __esm3_get_item__(self, name, crop_start, data_path = 'esm3_data/'):
        """
        Prepares input for the ESM3 model
        """
        pkl_fname = os.path.join(data_path, f"{name}.pkl")

        try:
            with open(pkl_fname, "rb") as f:
                esm_data = pickle.load(f)
        except:
            print(f'writing pkl for {name} {crop_start}')
            pdb_path = f'pdbs/{name}.pdb'
            protein = ESMProtein.from_pdb(pdb_path)

            self.model.eval()
            encoder = self.model.model.encode(protein)
            self.model.train()

            seq = encoder.sequence.cpu().detach()[1:-1][:700]
            struct = encoder.structure.cpu().detach()[1:-1][:700]

            sequence_tokens = np.full(700, 1, dtype=np.int32)       ## sequence pad token is 1
            structure_tokens = np.full(700, 4099, dtype=np.int32)   ## structure pad token is 4099

            sequence_tokens[:len(seq)] = seq
            structure_tokens[:len(struct)] = struct

            sequence_id = sequence_tokens != 1

            obj = {'name': name, 'len': len(seq), 'seq_tokens': sequence_tokens,
                   'struct_tokens': structure_tokens, 'sequence_id': sequence_id}

            with open(pkl_fname, 'wb') as f:
                pickle.dump(obj, f)

            with open(pkl_fname, "rb") as f:
                esm_data = pickle.load(f)

        if crop_start > -1:
            sequence_tokens = esm_data['seq_tokens'][crop_start:crop_start+self.crop_len]
            structure_tokens = esm_data['struct_tokens'][crop_start:crop_start+self.crop_len]
            sequence_id = esm_data['sequence_id'][crop_start:crop_start+self.crop_len]
        else:
            sequence_tokens = esm_data['seq_tokens'][:self.crop_len]
            structure_tokens = esm_data['struct_tokens'][:self.crop_len]
            sequence_id = esm_data['sequence_id'][:self.crop_len]

        eval_mask = np.zeros_like(sequence_tokens)
        eval_mask[:esm_data['len']] = 1
        eval_mask[sequence_tokens==self.proline] = 0

        return sequence_tokens, structure_tokens, sequence_id, eval_mask

    def __esm2_get_item__(self, obj, crop_start):
        """
        Prepares input for the ESM2 model
        """
        sequence = obj['sequence'].replace(' ','')
        if crop_start > -1:
            sequence = sequence[crop_start:crop_start+self.crop_len]

        sequence_tokens = self.tokenizer.encode(sequence,
                                                add_special_tokens=False,
                                                padding='max_length',
                                                max_length=self.crop_len,
                                                return_tensors='np').squeeze()

        # Set mask to 1 for length of sequence, padded tokens then are 0
        eval_mask = np.zeros_like(sequence_tokens)
        eval_mask[:len(sequence)] = 1
        sequence_id = eval_mask != 0
        eval_mask[sequence_tokens==self.proline] = 0

        return sequence_tokens, sequence_id, eval_mask

    def __get_dssp__(self, name, crop_start):
        """
        Prepares DSSP information for a given sequence
        """
        try:
            dssp_csv = pd.read_csv('data/dssp.csv')
            entry = dssp_csv.loc[dssp_csv.PDB == str(name)].iloc[0]
        except:
            entry = {}
            entry['DSSP'] = utils.calc_dssp(f'pdbs/{name}.pdb')

        assert len(entry) > 0
        if crop_start == -1:
            dssp_data = entry['DSSP'].replace(' ','')[:self.crop_len]
        else:
            dssp_data = entry['DSSP'].replace(' ','')[crop_start:crop_start + self.crop_len]

        dssp = np.zeros(self.crop_len)
        inds = [i for i, char in enumerate(dssp_data) if char=='C']
        dssp[inds] = 1.0

        inds = [i for i, char in enumerate(dssp_data) if char=='H']
        dssp[inds] = 2.0

        return dssp

    def __getitem__(self, idx):
        """
        Returns a dict with the appropriate entries for each model
        """
        exists = -1
        while exists == -1:
            name = self.names[idx]
            if self.sample_clusters:
                roptions = list(self.cluster_info.loc[self.cluster_info.cluster==name]['entry_ID'].values)
                options = [opt for opt in roptions if opt in self.all_names]
                name = random.choice(options)
            pkl_fname = f"data/mBMRB_data/{name}.pkl"

            try:
                with open(pkl_fname, "rb") as f:
                    obj = pickle.load(f)
                exists = 1
            except:
                print(f'{pkl_fname} not found')

        assns = obj['label']
        assns = vocab.mask_termini(assns)

        crop_start = -1
        if len(assns) > self.crop_len:
            crop_start = np.random.choice(range(0, len(assns)-self.crop_len))
            assns = assns[crop_start:crop_start + self.crop_len]

        labels = self.label_tokenizer.convert_tokens_to_ids(assns, pad_to_length=self.crop_len)
        labels = np.asarray(labels, np.int64)

        dssp = None
        if self.return_dssp:
            dssp = self.__get_dssp__(name, crop_start)
        if 'esm3' in self.method:
            sequence, structure, sequence_id, eval_mask = self.__esm3_get_item__(name, crop_start)
        elif 'esm2' in self.method:
            sequence, sequence_id, eval_mask = self.__esm2_get_item__(obj, crop_start)
        elif 'af2' in self.method:
            pair_rep, sequence_id, eval_mask = self.__af2_get_item__(name, obj, crop_start)
        elif 'baseline' in self.method:
            sequence, sequence_id, eval_mask = self.__baseline_get_item__(name, obj, crop_start)

        # Mask termini for eval. A -1 label corresponds to indices that are getting masked in vocabs
        eval_mask[labels==-1] = 0

        if 'esm2' in self.method:
            return sequence, sequence_id, eval_mask, labels, name, dssp
        elif 'esm3' in self.method:
            return sequence, structure, sequence_id, eval_mask, labels, name, dssp
        elif 'af2' in self.method:
            return pair_rep, labels, sequence_id, eval_mask, name, dssp
        elif 'baseline' in self.method:
            return sequence, labels, sequence_id, eval_mask, name, dssp

    def __collate_fn__(self, batch: List[Tuple[Any, ...]]):

        if 'baseline' in self.method:
            seqs, labels, sequence_id, eval_mask, names, dssp = tuple(zip(*batch))
            seqs = torch.tensor(np.array(seqs))

            labels = torch.from_numpy(np.array(labels)).float()
            eval_mask = torch.from_numpy(np.array(eval_mask))
            if self.return_dssp:
                dssp = torch.from_numpy(np.array(dssp))
            sequence_id = torch.from_numpy(np.array(sequence_id))

            output = {'names': names, 'seqs': seqs, 'seq_id': sequence_id,
                      'targets': labels, 'eval_mask': eval_mask, 'dssp': dssp}
            return output

        elif 'af2' in self.method:
            pair_reps, labels, sequence_id, eval_mask, names, dssp = tuple(zip(*batch))

            pair_reps = torch.from_numpy(np.array(pair_reps, dtype=np.float64))
            labels = torch.from_numpy(np.array(labels)).float()
            eval_mask = torch.from_numpy(np.array(eval_mask))
            sequence_id = torch.from_numpy(np.array(sequence_id, dtype=bool))

            if self.return_dssp:
                dssp = torch.from_numpy(np.array(dssp))
            output = {'names': names, 'pair_reps': pair_reps, 'targets': labels, "seq_id": sequence_id,
                      'eval_mask': eval_mask, 'dssp': dssp}
            return output

        elif 'esm2' in self.method:
            seqs, sequence_id, eval_mask, label, names, dssp = tuple(zip(*batch))
            seqs = torch.from_numpy(np.array(seqs))
            structs = None
            sequence_id = torch.from_numpy(np.array(sequence_id))

        elif 'esm3' in self.method:
            seqs, structs, sequence_id, eval_mask, label, names, dssp = tuple(zip(*batch))
            seqs = torch.from_numpy(np.array(seqs))
            structs = torch.from_numpy(np.array(structs))
            sequence_id = torch.from_numpy(np.array(sequence_id))

        eval_mask = torch.from_numpy(np.array(eval_mask))
        label = torch.from_numpy(np.array(label)).float()
        if self.return_dssp:
            dssp = torch.from_numpy(np.array(dssp))

        output = {'seqs': seqs, "structs": structs, "seq_id": sequence_id, "eval_mask": eval_mask,
                  'targets': label, 'names': names, 'dssp': dssp}

        return output
```
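A minimal sketch of how `DynaData` appears meant to be driven, based only on the class above: `method` is a `(name, model)` pair, `__collate_fn__` is handed to the `DataLoader`, and batches come back as a dict of tensors. The split name, the pickle locations, and the `None` model stand-in are assumptions for illustration, since the training script itself is not in this commit.

```python
# Illustrative only: exercises the ESM-2 path of DynaData as defined above.
# The 'relaxdb' split file, pickles under data/mBMRB_data/, and method=('esm2', None)
# are assumptions about how the (not yet released) training code uses this class.
import torch
from data.dataloader import DynaData

dataset = DynaData(
    split='relaxdb',          # expects data/split_files/relaxdb.txt
    type='rex',
    crop_len=367,
    method=('esm2', None),    # (method name, model); the ESM-2 path never touches the model
)

loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=dataset.__collate_fn__,
)

batch = next(iter(loader))
print(batch['seqs'].shape, batch['targets'].shape, batch['eval_mask'].shape)
```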
Dyna-1/data/vocab.py
ADDED
@@ -0,0 +1,104 @@
```python
from collections import OrderedDict
from typing import List

VOCAB_MISSING = OrderedDict([
    ('t',-1), # no data bc disordered terminus
    ('x',-1), # no data, R1/R2/NOE not reported
    ('p',0),  # proline, not evaluated
    ('A', 0), # nothing
    ('v',0),  # fast motion
    ('.',1),  # missing
    ('b',0),  # both fast and slow
    ('^',0)   # rex
])

VOCAB_REX = OrderedDict([
    ('t',-1), # no data bc disordered terminus
    ('x',-1), # no data, R1/R2/NOE not reported
    ('p',0),  # proline, not evaluated
    ('A', 0), # nothing
    ('v',0),  # fast motion
    ('.',1),  # missing
    ('b',1),  # both fast and slow
    ('^',1)   # rex
])

VOCAB_CPMG = OrderedDict([
    ('t',-1), # no data bc disordered terminus
    ('P',0),  # proline, not evaluated
    ('N',-1), # no data, assigned but CPMG not reported
    ('A', 0), # nothing
    ('.',1),  # missing
    ('X',1),  # exchange from Rex definition
    ('Y',0)   # exchange from unsuppressed R2
])

def mask_termini(seq):
    """
    Mask the termini of a sequence
    """
    seq = seq.lstrip('.p').rjust(len(seq), 't')
    seq = seq.rstrip('.p').ljust(len(seq), 't')
    return seq

class label_tokenizer():
    def __init__(self,
                 type = 'missing',
                 missing_only = False,
                 rex_only = False,
                 unsuppressed = False):
        """
        Tokenize the data labeling for BMRB, REX, and CPMG

        Args:
            type: (str) which type of experiment
            missing_only: (bool) only return residues with missing peaks
            rex_only: (bool) only return residues with Rex
            unsuppressed: (bool) return residues with unsuppressed Rex
        """
        if type == 'missing':
            self.vocab = VOCAB_MISSING.copy()
        elif type == 'rex':
            self.vocab = VOCAB_REX.copy()
            if missing_only:
                self.vocab['b'] = 0
                self.vocab['^'] = 0
        elif type == 'cpmg':
            self.vocab = VOCAB_CPMG.copy()
            if missing_only:
                self.vocab['X'] = 0
            if unsuppressed:
                self.vocab['Y'] = 1
        if rex_only:
            self.vocab['.'] = -1
        self.tokens = list(self.vocab.keys())

    @property
    def vocab_size(self) -> int:
        return len(self.vocab)

    def convert_token_to_id(self, token: str) -> int:
        """ Converts a token (str/unicode) in an id using the vocab. """
        try:
            return self.vocab[token]
        except KeyError:
            raise KeyError(f"Unrecognized token: '{token}'")

    def convert_tokens_to_ids(self, tokens: List[str], pad_to_length = None) -> List[int]:
        """Converts a list of tokens (str/unicode) into a list of ids (int) using the vocab. """

        if pad_to_length is None:
            return [self.convert_token_to_id(token) for token in tokens]
        else:
            return [self.convert_token_to_id(token) for token in tokens] + [0] * (pad_to_length - len(tokens))

    def convert_id_to_token(self, index: int) -> str:
        """Converts an index (integer) in a token (string/unicode) using the vocab."""
        try:
            return self.tokens[index]
        except IndexError:
            raise IndexError(f"Unrecognized index: '{index}'")

    def convert_ids_to_tokens(self, indices: List[int]) -> List[str]:
        """Converts a list of indices (integer) into a list of tokens (string/unicode) using the vocab."""
        return [self.convert_id_to_token(id_) for id_ in indices]
```
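For orientation, a short example of how the vocabularies above turn a label string into training targets; the label string is made up, but the token meanings follow the comments in `VOCAB_REX`.

```python
# Toy walk-through of the label vocabulary above in 'rex' mode:
# '.', '^' and 'b' map to 1; 'A', 'v', 'p' map to 0; 't' and 'x' map to -1 (ignored).
from data.vocab import label_tokenizer, mask_termini

labels = '..AAv^bAp.Ax..'        # hypothetical per-residue annotation string
labels = mask_termini(labels)    # leading/trailing '.'/'p' become disordered termini ('t')
tok = label_tokenizer(type='rex')

ids = tok.convert_tokens_to_ids(labels, pad_to_length=20)
print(labels)  # ttAAv^bAp.Axtt
print(ids)     # [-1, -1, 0, 0, 0, 1, 1, 0, 0, 1, 0, -1, -1, -1, 0, 0, 0, 0, 0, 0]
```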
Dyna-1/esm/__init__.py
ADDED
@@ -0,0 +1,2 @@
```python
__version__ = "3.1.1"
```
Dyna-1/esm/data/ParentChildTreeFile.txt
ADDED
The diff for this file is too large to render.
Dyna-1/esm/data/entry_list_safety_29026.list
ADDED
The diff for this file is too large to render.
Dyna-1/esm/data/interpro_29026_to_keywords_58641.csv
ADDED
The diff for this file is too large to render.
Dyna-1/esm/data/keyword_idf_safety_filtered_58641.npy
ADDED
Binary file (469 kB).
Dyna-1/esm/data/keyword_vocabulary_safety_filtered_58641.txt
ADDED
The diff for this file is too large to render.
Dyna-1/esm/layers/attention.py
ADDED
@@ -0,0 +1,76 @@
```python
import functools

import einops
import torch
import torch.nn.functional as F
from torch import nn

from esm.layers.rotary import RotaryEmbedding


class MultiHeadAttention(nn.Module):
    def __init__(
        self, d_model: int, n_heads: int, bias: bool = False, qk_layernorm: bool = True
    ):
        super().__init__()

        self.d_model = d_model
        self.n_heads = n_heads

        self.d_head = self.d_model // self.n_heads
        self.layernorm_qkv = nn.Sequential(
            nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=bias)
        )
        self.out_proj = nn.Linear(d_model, d_model, bias=bias)

        if qk_layernorm:
            self.q_ln = nn.LayerNorm(d_model, bias=bias)
            self.k_ln = nn.LayerNorm(d_model, bias=bias)
        else:
            self.q_ln = nn.Identity()
            self.k_ln = nn.Identity()

        self.rotary = RotaryEmbedding(d_model // n_heads)

    def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor):
        q = q.unflatten(-1, (self.n_heads, self.d_head))
        k = k.unflatten(-1, (self.n_heads, self.d_head))
        q, k = self.rotary(q, k)
        q = q.flatten(-2, -1)
        k = k.flatten(-2, -1)
        return q, k

    def forward(self, x, seq_id):
        qkv_BLD3 = self.layernorm_qkv(x)
        query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
        query_BLD, key_BLD = (
            self.q_ln(query_BLD).to(query_BLD.dtype),
            self.k_ln(key_BLD).to(query_BLD.dtype),
        )
        query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)

        n_heads = self.n_heads
        reshaper = functools.partial(
            einops.rearrange, pattern="b s (h d) -> b h s d", h=n_heads
        )

        query_BHLD, key_BHLD, value_BHLD = map(
            reshaper, (query_BLD, key_BLD, value_BLD)
        )

        if seq_id is not None:
            # Where True, enable participation in attention.
            mask_BLL = seq_id.unsqueeze(-1) == seq_id.unsqueeze(-2)
            mask_BHLL = mask_BLL.unsqueeze(1)

            context_BHLD = F.scaled_dot_product_attention(
                query_BHLD, key_BHLD, value_BHLD, mask_BHLL
            )
        else:
            # Shortcut, if we don't use attention biases then torch
            # will autoselect flashattention as the implementation
            context_BHLD = F.scaled_dot_product_attention(
                query_BHLD, key_BHLD, value_BHLD
            )
        context_BLD = einops.rearrange(context_BHLD, "b h s d -> b s (h d)")
        return self.out_proj(context_BLD)
```
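A small shape check for the attention module above; the sizes are arbitrary, and `seq_id` here is just a boolean padding mask, as the `sequence_id` tensors elsewhere in this commit are.

```python
# Standalone shape check for MultiHeadAttention (sizes chosen arbitrarily).
import torch
from esm.layers.attention import MultiHeadAttention

d_model, n_heads, B, L = 64, 8, 2, 10
attn = MultiHeadAttention(d_model, n_heads)

x = torch.randn(B, L, d_model)
seq_id = torch.ones(B, L, dtype=torch.bool)   # True where tokens are real (non-padding)
seq_id[1, 7:] = False                         # pretend the second sequence is shorter

out = attn(x, seq_id)
print(out.shape)   # torch.Size([2, 10, 64])
```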
Dyna-1/esm/layers/blocks.py
ADDED
@@ -0,0 +1,153 @@
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from esm.layers.attention import MultiHeadAttention
from esm.layers.geom_attention import (
    GeometricReasoningOriginalImpl,
)
from esm.utils.structure.affine3d import Affine3D


def swiglu_correction_fn(expansion_ratio: float, d_model: int) -> int:
    # set hidden dimension to nearest multiple of 256 after expansion ratio
    return int(((expansion_ratio * d_model) + 255) // 256 * 256)


class SwiGLU(nn.Module):
    """
    SwiGLU activation function as an nn.Module, allowing it to be used within nn.Sequential.
    This module splits the input tensor along the last dimension and applies the SiLU (Swish)
    activation function to the first half, then multiplies it by the second half.
    """

    def __init__(self):
        super(SwiGLU, self).__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x.chunk(2, dim=-1)
        return F.silu(x1) * x2


def swiglu_ln_ffn(d_model: int, expansion_ratio: float, bias: bool):
    return nn.Sequential(
        nn.LayerNorm(d_model),
        nn.Linear(
            d_model, swiglu_correction_fn(expansion_ratio, d_model) * 2, bias=bias
        ),
        SwiGLU(),
        nn.Linear(swiglu_correction_fn(expansion_ratio, d_model), d_model, bias=bias),
    )


def gelu_ln_ffn(d_model: int, expansion_ratio: float, bias: bool):
    hidden_dim = int(expansion_ratio * d_model)
    return nn.Sequential(
        nn.LayerNorm(d_model),
        nn.Linear(d_model, hidden_dim, bias=bias),
        nn.GELU(),
        nn.Linear(hidden_dim, d_model, bias=bias),
    )


class UnifiedTransformerBlock(nn.Module):
    """
    A unified transformer block that can optionally incorporate geometric attention.

    This class defines a transformer block that can be configured to use geometric attention
    alongside the standard multi-head attention mechanism. It is designed to be a flexible
    component of transformer-based models, allowing for the integration of geometric reasoning.

    Parameters
    ----------
    d_model : int
        The dimensionality of the input and output features of the transformer block.
    n_heads : int
        The number of attention heads in the multi-head attention mechanism.
    n_layers : int
        The number of layers in the transformer block.
    use_geom_attn : bool, optional
        Whether to use geometric attention in addition to the standard multi-head attention. Defaults to False.
    v_heads : int, optional
        The number of heads to use for the geometric attention mechanism, if enabled. Must be specified if `use_geom_attn` is True.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        use_geom_attn: bool = False,
        use_plain_attn: bool = True,
        v_heads: int | None = None,
        bias: bool = False,
        expansion_ratio: float = 4.0,
        residue_scaling_factor: float = 1,
        mask_and_zero_frameless: bool = False,
        qk_layernorm: bool = True,
        ffn_type: str = "swiglu",  # swiglu | gelu
    ):
        super().__init__()
        self.use_plain_attn = use_plain_attn
        if self.use_plain_attn:
            self.attn = MultiHeadAttention(
                d_model, n_heads, bias, qk_layernorm=qk_layernorm
            )
        self.use_geom_attn = use_geom_attn
        if self.use_geom_attn:
            if v_heads is None:
                raise ValueError("v_heads must be specified when use_geom_attn is True")
            self.geom_attn = GeometricReasoningOriginalImpl(
                c_s=d_model,
                v_heads=v_heads,
                bias=bias,
                mask_and_zero_frameless=mask_and_zero_frameless,
            )
        if ffn_type == "swiglu":
            self.ffn = swiglu_ln_ffn(d_model, expansion_ratio, bias)
        elif ffn_type == "gelu":
            self.ffn = gelu_ln_ffn(d_model, expansion_ratio, bias)
        else:
            raise ValueError(f"Unknown ffn_type: {ffn_type}")
        self.scaling_factor = residue_scaling_factor

    def forward(
        self,
        x: torch.Tensor,
        sequence_id: torch.Tensor,
        frames: Affine3D,
        frames_mask: torch.Tensor,
        chain_id: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass for the UnifiedTransformerBlock.

        Parameters
        ----------
        x : torch.Tensor[float]
            Input tensor to the transformer block, typically the output from the previous layer.
        sequence_id : torch.Tensor[int]
            Tensor containing sequence IDs for each element in the batch, used for attention masking.
        frames : Affine3D
            Affine3D containing geometric frame information for geometric attention.
        frames_mask : torch.Tensor[bool]
            Boolean mask tensor indicating valid frames for geometric attention.
        chain_id : torch.Tensor[int]
            Tensor containing chain IDs for each element, used for attention masking in geometric attention.

        Returns
        -------
        torch.Tensor[float]
            The output tensor after applying the transformer block operations.
        """
        if self.use_plain_attn:
            r1 = self.attn(x, sequence_id)
            x = x + r1 / self.scaling_factor

        if self.use_geom_attn:
            r2 = self.geom_attn(x, frames, frames_mask, sequence_id, chain_id)
            x = x + r2 / self.scaling_factor

        r3 = self.ffn(x) / self.scaling_factor
        x = x + r3

        return x
```
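To show how the plain-attention path of this block composes, a sequence-only sketch follows (assuming the full `esm` package in this repo is importable). With `use_geom_attn=False` the geometric arguments are never read, so passing `None` for them is assumed to be acceptable here.

```python
# Sequence-only use of UnifiedTransformerBlock: with use_geom_attn=False the
# frames/frames_mask/chain_id arguments are unused, so None is passed for them.
import torch
from esm.layers.blocks import UnifiedTransformerBlock

block = UnifiedTransformerBlock(d_model=64, n_heads=8, use_geom_attn=False)

x = torch.randn(2, 10, 64)
sequence_id = torch.ones(2, 10, dtype=torch.bool)

y = block(x, sequence_id, frames=None, frames_mask=None, chain_id=None)
print(y.shape)   # torch.Size([2, 10, 64])
```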
Dyna-1/esm/layers/codebook.py
ADDED
@@ -0,0 +1,88 @@
```python
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F


class EMACodebook(nn.Module):
    def __init__(
        self,
        n_codes,
        embedding_dim,
        no_random_restart=True,
        restart_thres=1.0,
        ema_decay=0.99,
    ):
        super().__init__()
        self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim))
        self.register_buffer("N", torch.zeros(n_codes))
        self.register_buffer("z_avg", self.embeddings.data.clone())

        self.n_codes = n_codes
        self.embedding_dim = embedding_dim
        self._need_init = True
        self.no_random_restart = no_random_restart
        self.restart_thres = restart_thres
        self.freeze_codebook = False
        self.ema_decay = ema_decay

    def reset_parameters(self):
        # For meta init
        pass

    def _tile(self, x):
        d, ew = x.shape
        if d < self.n_codes:
            n_repeats = (self.n_codes + d - 1) // d
            std = 0.01 / np.sqrt(ew)
            x = x.repeat(n_repeats, 1)
            x = x + torch.randn_like(x) * std
        return x

    def _init_embeddings(self, z):
        # z: [b, t, c]
        self._need_init = False
        flat_inputs = z.view(-1, self.embedding_dim)
        y = self._tile(flat_inputs)

        y.shape[0]
        _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]
        if dist.is_initialized():
            dist.broadcast(_k_rand, 0)
        self.embeddings.data.copy_(_k_rand)
        self.z_avg.data.copy_(_k_rand)
        self.N.data.copy_(torch.ones(self.n_codes))

    def forward(self, z):
        # z: [b, t, c]
        if self._need_init and self.training and not self.freeze_codebook:
            self._init_embeddings(z)
        # z is of shape [batch_size, sequence length, channels]
        flat_inputs = z.view(-1, self.embedding_dim)
        distances = (
            (flat_inputs**2).sum(dim=1, keepdim=True)
            - 2 * flat_inputs @ self.embeddings.t()
            + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True)
        )  # [bt, c]

        encoding_indices = torch.argmin(distances, dim=1)
        encoding_indices = encoding_indices.view(*z.shape[:2])  # [b, t, ncode]

        embeddings = F.embedding(encoding_indices, self.embeddings)  # [b, t, c]

        commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach())

        # EMA codebook update
        if self.training and not self.freeze_codebook:
            assert False, "Not implemented"
        embeddings_st = (embeddings - z).detach() + z

        return embeddings_st, encoding_indices, commitment_loss

    def dictionary_lookup(self, encodings):
        embeddings = F.embedding(encodings, self.embeddings)
        return embeddings

    def soft_codebook_lookup(self, weights: torch.Tensor) -> torch.Tensor:
        return weights @ self.embeddings
```
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from torch import Tensor
|
4 |
+
|
5 |
+
# NOT CURRENTLY USED
|
6 |
+
|
7 |
+
|
8 |
+
class SwiGLU(nn.Module):
|
9 |
+
def __init__(self) -> None:
|
10 |
+
super().__init__()
|
11 |
+
|
12 |
+
def forward(self, x: Tensor) -> Tensor:
|
13 |
+
x1, x2 = x.chunk(2, dim=-1)
|
14 |
+
hidden = F.silu(x1) * x2
|
15 |
+
return hidden
|
16 |
+
|
17 |
+
|
18 |
+
class FFN(nn.Module):
|
19 |
+
def __init__(self, in_proj, activation, out_proj) -> None:
|
20 |
+
super().__init__()
|
21 |
+
self.in_proj = in_proj
|
22 |
+
self.activation = activation
|
23 |
+
self.out_proj = out_proj
|
24 |
+
|
25 |
+
def forward(self, x: Tensor) -> Tensor:
|
26 |
+
x = self.in_proj(x)
|
27 |
+
x = self.activation(x)
|
28 |
+
x = self.out_proj(x)
|
29 |
+
return x
|
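A short sketch of composing the two modules above in the usual SwiGLU arrangement, where the input projection is twice the hidden width so that SwiGLU can split it in half; the dimensions are illustrative.

import torch
import torch.nn as nn

from esm.layers.ffn import FFN, SwiGLU

d_model, d_hidden = 64, 128
ffn = FFN(
    in_proj=nn.Linear(d_model, 2 * d_hidden, bias=False),  # SwiGLU chunks this into two halves
    activation=SwiGLU(),
    out_proj=nn.Linear(d_hidden, d_model, bias=False),
)
x = torch.randn(3, 10, d_model)
assert ffn(x).shape == x.shape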
Dyna-1/esm/layers/geom_attention.py
ADDED
@@ -0,0 +1,149 @@
1 |
+
from math import sqrt
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from einops import rearrange
|
5 |
+
from torch import nn
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
|
9 |
+
class GeometricReasoningOriginalImpl(nn.Module):
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
c_s: int,
|
13 |
+
v_heads: int,
|
14 |
+
num_vector_messages: int = 1,
|
15 |
+
mask_and_zero_frameless: bool = True,
|
16 |
+
divide_residual_by_depth: bool = False,
|
17 |
+
bias: bool = False,
|
18 |
+
):
|
19 |
+
"""Approximate implementation:
|
20 |
+
|
21 |
+
ATTN(A, v) := (softmax_j A_ij) v_j
|
22 |
+
make_rot_vectors(x) := R(i->g) Linear(x).reshape(..., 3)
|
23 |
+
make_vectors(x) := T(i->g) Linear(x).reshape(..., 3)
|
24 |
+
|
25 |
+
v <- make_rot_vectors(x)
|
26 |
+
q_dir, k_dir <- make_rot_vectors(x)
|
27 |
+
q_dist, k_dist <- make_vectors(x)
|
28 |
+
|
29 |
+
A_ij <- dot(q_dir_i, k_dir_j) - ||q_dist_i - k_dist_j||^2
|
30 |
+
x <- x + Linear(T(g->i) ATTN(A, v))
|
31 |
+
"""
|
32 |
+
super().__init__()
|
33 |
+
self.c_s = c_s
|
34 |
+
self.v_heads = v_heads
|
35 |
+
self.num_vector_messages = num_vector_messages
|
36 |
+
self.mask_and_zero_frameless = mask_and_zero_frameless
|
37 |
+
|
38 |
+
self.s_norm = nn.LayerNorm(c_s, bias=bias)
|
39 |
+
dim_proj = (
|
40 |
+
4 * self.v_heads * 3 + self.v_heads * 3 * self.num_vector_messages
|
41 |
+
) # 2 x (q, k) * number of heads * (x, y, z) + number of heads * number of vector messages * (x, y, z)
|
42 |
+
self.proj = nn.Linear(c_s, dim_proj, bias=bias)
|
43 |
+
channels_out = self.v_heads * 3 * self.num_vector_messages
|
44 |
+
self.out_proj = nn.Linear(channels_out, c_s, bias=bias)
|
45 |
+
|
46 |
+
# The basic idea is for some attention heads to pay more or less attention to rotation versus distance,
|
47 |
+
# as well as to control the sharpness of the softmax (i.e., should this head only attend to those residues
|
48 |
+
# very nearby, or should there be a shallower drop-off in attention weight?)
|
49 |
+
self.distance_scale_per_head = nn.Parameter(torch.zeros((self.v_heads)))
|
50 |
+
self.rotation_scale_per_head = nn.Parameter(torch.zeros((self.v_heads)))
|
51 |
+
|
52 |
+
def forward(self, s, affine, affine_mask, sequence_id, chain_id):
|
53 |
+
if sequence_id is None:
|
54 |
+
sequence_id = torch.zeros_like(s[..., 0], dtype=torch.int64)
|
55 |
+
attn_bias = sequence_id.unsqueeze(-1) == sequence_id.unsqueeze(-2)
|
56 |
+
attn_bias = attn_bias.unsqueeze(1).float()
|
57 |
+
attn_bias = attn_bias.masked_fill(
|
58 |
+
~affine_mask[:, None, None, :], torch.finfo(attn_bias.dtype).min
|
59 |
+
)
|
60 |
+
chain_id_mask = chain_id.unsqueeze(1) != chain_id.unsqueeze(2)
|
61 |
+
attn_bias = attn_bias.masked_fill(
|
62 |
+
chain_id_mask.unsqueeze(1), torch.finfo(s.dtype).min
|
63 |
+
)
|
64 |
+
|
65 |
+
ns = self.s_norm(s)
|
66 |
+
vec_rot, vec_dist = self.proj(ns).split(
|
67 |
+
[
|
68 |
+
self.v_heads * 2 * 3 + self.v_heads * 3 * self.num_vector_messages,
|
69 |
+
self.v_heads * 2 * 3,
|
70 |
+
],
|
71 |
+
dim=-1,
|
72 |
+
)
|
73 |
+
|
74 |
+
# Rotate the queries and keys for the rotation term. We also rotate the values.
|
75 |
+
# NOTE(zeming, thayes): Values are only rotated, not translated. We may wish to change
|
76 |
+
# this in the future.
|
77 |
+
query_rot, key_rot, value = (
|
78 |
+
affine.rot[..., None]
|
79 |
+
.apply(rearrange(vec_rot, "... (h c) -> ... h c", c=3))
|
80 |
+
.split(
|
81 |
+
[self.v_heads, self.v_heads, self.v_heads * self.num_vector_messages],
|
82 |
+
dim=-2,
|
83 |
+
)
|
84 |
+
)
|
85 |
+
|
86 |
+
# Rotate and translate the queries and keys for the distance term
|
87 |
+
# NOTE(thayes): a simple speedup would be to apply all rotations together, then
|
88 |
+
# separately apply the translations.
|
89 |
+
query_dist, key_dist = (
|
90 |
+
affine[..., None]
|
91 |
+
.apply(rearrange(vec_dist, "... (h c) -> ... h c", c=3))
|
92 |
+
.chunk(2, dim=-2)
|
93 |
+
)
|
94 |
+
|
95 |
+
query_dist = rearrange(query_dist, "b s h d -> b h s 1 d")
|
96 |
+
key_dist = rearrange(key_dist, "b s h d -> b h 1 s d")
|
97 |
+
query_rot = rearrange(query_rot, "b s h d -> b h s d")
|
98 |
+
key_rot = rearrange(key_rot, "b s h d -> b h d s")
|
99 |
+
value = rearrange(
|
100 |
+
value, "b s (h m) d -> b h s (m d)", m=self.num_vector_messages
|
101 |
+
)
|
102 |
+
|
103 |
+
distance_term = (query_dist - key_dist).norm(dim=-1) / sqrt(3)
|
104 |
+
rotation_term = query_rot.matmul(key_rot) / sqrt(3)
|
105 |
+
distance_term_weight = rearrange(
|
106 |
+
F.softplus(self.distance_scale_per_head), "h -> h 1 1"
|
107 |
+
)
|
108 |
+
rotation_term_weight = rearrange(
|
109 |
+
F.softplus(self.rotation_scale_per_head), "h -> h 1 1"
|
110 |
+
)
|
111 |
+
|
112 |
+
attn_weight = (
|
113 |
+
rotation_term * rotation_term_weight - distance_term * distance_term_weight
|
114 |
+
)
|
115 |
+
|
116 |
+
if attn_bias is not None:
|
117 |
+
# we can re-use the attention bias from the transformer layers
|
118 |
+
# NOTE(thayes): This attention bias is expected to handle two things:
|
119 |
+
# 1. Masking attention on padding tokens
|
120 |
+
# 2. Masking cross sequence attention in the case of bin packing
|
121 |
+
s_q = attn_weight.size(2)
|
122 |
+
s_k = attn_weight.size(3)
|
123 |
+
_s_q = max(0, attn_bias.size(2) - s_q)
|
124 |
+
_s_k = max(0, attn_bias.size(3) - s_k)
|
125 |
+
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
|
126 |
+
attn_weight = attn_weight + attn_bias
|
127 |
+
|
128 |
+
attn_weight = torch.softmax(attn_weight, dim=-1)
|
129 |
+
|
130 |
+
attn_out = attn_weight.matmul(value)
|
131 |
+
|
132 |
+
attn_out = (
|
133 |
+
affine.rot[..., None]
|
134 |
+
.invert()
|
135 |
+
.apply(
|
136 |
+
rearrange(
|
137 |
+
attn_out, "b h s (m d) -> b s (h m) d", m=self.num_vector_messages
|
138 |
+
)
|
139 |
+
)
|
140 |
+
)
|
141 |
+
|
142 |
+
attn_out = rearrange(
|
143 |
+
attn_out, "b s (h m) d -> b s (h m d)", m=self.num_vector_messages
|
144 |
+
)
|
145 |
+
if self.mask_and_zero_frameless:
|
146 |
+
attn_out = attn_out.masked_fill(~affine_mask[..., None], 0.0)
|
147 |
+
s = self.out_proj(attn_out)
|
148 |
+
|
149 |
+
return s
|
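For orientation, a standalone sketch of the pairwise attention bias that forward() above assembles from sequence_id, chain_id, and affine_mask before adding it to the rotation/distance scores; the toy tensors are made up.

import torch

sequence_id = torch.tensor([[0, 0, 1, 1]])                 # two bin-packed sequences
chain_id = torch.tensor([[0, 0, 0, 1]])                    # last residue on another chain
affine_mask = torch.tensor([[True, True, True, False]])    # last residue has no frame

attn_bias = (sequence_id.unsqueeze(-1) == sequence_id.unsqueeze(-2)).unsqueeze(1).float()
attn_bias = attn_bias.masked_fill(
    ~affine_mask[:, None, None, :], torch.finfo(attn_bias.dtype).min
)
chain_id_mask = chain_id.unsqueeze(1) != chain_id.unsqueeze(2)
attn_bias = attn_bias.masked_fill(chain_id_mask.unsqueeze(1), torch.finfo(attn_bias.dtype).min)
# attn_bias: [B, 1, L, L]; 1.0 for same-sequence pairs, 0.0 across sequences, and a large
# negative value for frameless keys or cross-chain pairs, matching the masking in forward().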
Dyna-1/esm/layers/regression_head.py
ADDED
@@ -0,0 +1,22 @@
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
|
4 |
+
def RegressionHead(
|
5 |
+
d_model: int, output_dim: int, hidden_dim: int | None = None
|
6 |
+
) -> nn.Module:
|
7 |
+
"""Single-hidden layer MLP for supervised output.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
d_model: input dimension
|
11 |
+
output_dim: dimensionality of the output.
|
12 |
+
hidden_dim: optional dimension of hidden layer, defaults to d_model.
|
13 |
+
Returns:
|
14 |
+
output MLP module.
|
15 |
+
"""
|
16 |
+
hidden_dim = hidden_dim if hidden_dim is not None else d_model
|
17 |
+
return nn.Sequential(
|
18 |
+
nn.Linear(d_model, hidden_dim),
|
19 |
+
nn.GELU(),
|
20 |
+
nn.LayerNorm(hidden_dim),
|
21 |
+
nn.Linear(hidden_dim, output_dim),
|
22 |
+
)
|
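A quick sketch of the factory above; the dimensions are arbitrary.

import torch
from esm.layers.regression_head import RegressionHead

head = RegressionHead(d_model=1536, output_dim=64)  # Linear -> GELU -> LayerNorm -> Linear
x = torch.randn(2, 100, 1536)                       # (batch, length, d_model)
logits = head(x)                                    # -> (2, 100, 64)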
Dyna-1/esm/layers/rotary.py
ADDED
@@ -0,0 +1,221 @@
1 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
2 |
+
#
|
3 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
4 |
+
# and OPT implementations in this library. It has been modified from its
|
5 |
+
# original forms to accommodate minor architectural differences compared
|
6 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
7 |
+
#
|
8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
9 |
+
# you may not use this file except in compliance with the License.
|
10 |
+
# You may obtain a copy of the License at
|
11 |
+
#
|
12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
13 |
+
#
|
14 |
+
# Unless required by applicable law or agreed to in writing, software
|
15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
17 |
+
# See the License for the specific language governing permissions and
|
18 |
+
# limitations under the License.
|
19 |
+
# NOTE: this implementation is from LLaMA 2:
|
20 |
+
# https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/08639a72e17836184096ae6a7e2766f2a34c3e36/modeling_flash_llama.py#L114
|
21 |
+
# Flash attention rotary implementation can be installed like so: `pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary`
|
22 |
+
|
23 |
+
from typing import Tuple
|
24 |
+
|
25 |
+
import torch
|
26 |
+
from einops import rearrange, repeat
|
27 |
+
|
28 |
+
|
29 |
+
def rotate_half(x, interleaved=False):
|
30 |
+
if not interleaved:
|
31 |
+
x1, x2 = x.chunk(2, dim=-1)
|
32 |
+
return torch.cat((-x2, x1), dim=-1)
|
33 |
+
else:
|
34 |
+
x1, x2 = x[..., ::2], x[..., 1::2]
|
35 |
+
return rearrange(
|
36 |
+
torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
|
37 |
+
)
|
38 |
+
|
39 |
+
|
40 |
+
def apply_rotary_emb_torch(x, cos, sin, interleaved=False, _inplace=False):
|
41 |
+
"""
|
42 |
+
x: (batch_size, seqlen, nheads, headdim)
|
43 |
+
cos, sin: (seqlen, rotary_dim / 2)
|
44 |
+
"""
|
45 |
+
ro_dim = cos.shape[-1] * 2
|
46 |
+
assert ro_dim <= x.shape[-1]
|
47 |
+
seqlen = x.size(1)
|
48 |
+
cos = cos[:seqlen]
|
49 |
+
sin = sin[:seqlen]
|
50 |
+
cos = repeat(cos, "s d -> s 1 (2 d)")
|
51 |
+
sin = repeat(sin, "s d -> s 1 (2 d)")
|
52 |
+
return torch.cat(
|
53 |
+
[
|
54 |
+
x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
|
55 |
+
x[..., ro_dim:],
|
56 |
+
],
|
57 |
+
dim=-1,
|
58 |
+
)
|
59 |
+
|
60 |
+
|
61 |
+
class RotaryEmbedding(torch.nn.Module):
|
62 |
+
"""
|
63 |
+
The rotary position embeddings from RoFormer_ (Su et al.).
|
64 |
+
A crucial insight from the method is that the query and keys are
|
65 |
+
transformed by rotation matrices which depend on the relative positions.
|
66 |
+
Other implementations are available in the Rotary Transformer repo_ and in
|
67 |
+
GPT-NeoX_; GPT-NeoX was an inspiration.
|
68 |
+
.. _RoFormer: https://arxiv.org/abs/2104.09864
|
69 |
+
.. _repo: https://github.com/ZhuiyiTechnology/roformer
|
70 |
+
.. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
|
71 |
+
If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
|
72 |
+
A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
|
73 |
+
Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
|
74 |
+
"""
|
75 |
+
|
76 |
+
def __init__(
|
77 |
+
self,
|
78 |
+
dim: int,
|
79 |
+
base=10000.0,
|
80 |
+
interleaved=False,
|
81 |
+
scale_base=None,
|
82 |
+
scaling_factor=1.0,
|
83 |
+
pos_idx_in_fp32=True,
|
84 |
+
device=None,
|
85 |
+
):
|
86 |
+
"""
|
87 |
+
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
|
88 |
+
of 1st half and 2nd half (GPT-NeoX style).
|
89 |
+
pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
|
90 |
+
otherwise they might be in lower precision.
|
91 |
+
This option was added because previously (before 2023-07-02), when we construct
|
92 |
+
the position indices, we use the dtype of self.inv_freq. In most cases this would
|
93 |
+
be fp32, but if the model is trained in pure bf16 (not mixed precision), then
|
94 |
+
self.inv_freq would be bf16, and the position indices are also in bf16.
|
95 |
+
Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
|
96 |
+
embeddings for some positions will coincide.
|
97 |
+
To maintain compatibility with models previously trained in pure bf16,
|
98 |
+
we add this option.
|
99 |
+
scaling_factor: RotaryEmbedding extended with linear scaling.
|
100 |
+
"""
|
101 |
+
super().__init__()
|
102 |
+
self.dim = dim
|
103 |
+
self.base = float(base)
|
104 |
+
self.pos_idx_in_fp32 = pos_idx_in_fp32
|
105 |
+
# Generate and save the inverse frequency buffer (non trainable)
|
106 |
+
self.interleaved = interleaved
|
107 |
+
self.scale_base = scale_base
|
108 |
+
self.scaling_factor = scaling_factor
|
109 |
+
self.device = device
|
110 |
+
|
111 |
+
self._seq_len_cached = 0
|
112 |
+
self._cos_cached = None
|
113 |
+
self._sin_cached = None
|
114 |
+
self._cos_k_cached = None
|
115 |
+
self._sin_k_cached = None
|
116 |
+
self.reset_parameters()
|
117 |
+
|
118 |
+
def reset_parameters(self):
|
119 |
+
inv_freq = self._compute_inv_freq(self.device)
|
120 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
121 |
+
arange = torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32)
|
122 |
+
scale = (
|
123 |
+
(arange + 0.4 * self.dim) / (1.4 * self.dim)
|
124 |
+
if self.scale_base is not None
|
125 |
+
else None
|
126 |
+
)
|
127 |
+
self.register_buffer("scale", scale)
|
128 |
+
|
129 |
+
def _compute_inv_freq(self, device=None):
|
130 |
+
return 1 / (
|
131 |
+
self.base
|
132 |
+
** (
|
133 |
+
torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
|
134 |
+
/ self.dim
|
135 |
+
)
|
136 |
+
)
|
137 |
+
|
138 |
+
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
|
139 |
+
# Reset the tables if the sequence length has changed,
|
140 |
+
# if we're on a new device (possibly due to tracing for instance),
|
141 |
+
# or if we're switching from inference mode to training
|
142 |
+
if (
|
143 |
+
seqlen > self._seq_len_cached
|
144 |
+
or self._cos_cached is None
|
145 |
+
or self._cos_cached.device != device
|
146 |
+
or self._cos_cached.dtype != dtype
|
147 |
+
or (self.training and self._cos_cached.is_inference())
|
148 |
+
):
|
149 |
+
self._seq_len_cached = seqlen
|
150 |
+
# We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
|
151 |
+
# And the output of arange can be quite large, so bf16 would lose a lot of precision.
|
152 |
+
# However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
|
153 |
+
if self.pos_idx_in_fp32:
|
154 |
+
t = torch.arange(seqlen, device=device, dtype=torch.float32)
|
155 |
+
t /= self.scaling_factor
|
156 |
+
# We want fp32 here as well since inv_freq will be multiplied with t, and the output
|
157 |
+
# will be large. Having it in bf16 will lose a lot of precision and cause the
|
158 |
+
# cos & sin output to change significantly.
|
159 |
+
# We want to recompute self.inv_freq if it was not loaded in fp32
|
160 |
+
if self.inv_freq.dtype != torch.float32:
|
161 |
+
inv_freq = self.inv_freq.to(torch.float32)
|
162 |
+
else:
|
163 |
+
inv_freq = self.inv_freq
|
164 |
+
else:
|
165 |
+
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
166 |
+
t /= self.scaling_factor
|
167 |
+
inv_freq = self.inv_freq
|
168 |
+
# Don't do einsum, it converts fp32 to fp16 under AMP
|
169 |
+
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
170 |
+
freqs = torch.outer(t, inv_freq)
|
171 |
+
|
172 |
+
if self.scale is None:
|
173 |
+
self._cos_cached = torch.cos(freqs).to(dtype)
|
174 |
+
self._sin_cached = torch.sin(freqs).to(dtype)
|
175 |
+
else:
|
176 |
+
power = (
|
177 |
+
torch.arange(
|
178 |
+
seqlen, dtype=self.scale.dtype, device=self.scale.device
|
179 |
+
)
|
180 |
+
- seqlen // 2
|
181 |
+
) / self.scale_base
|
182 |
+
scale = self.scale.to(device=power.device) ** power.unsqueeze(-1)
|
183 |
+
# We want the multiplication by scale to happen in fp32
|
184 |
+
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
|
185 |
+
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
|
186 |
+
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
|
187 |
+
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
|
188 |
+
|
189 |
+
def forward(
|
190 |
+
self, q: torch.Tensor, k: torch.Tensor, seqlen_offset: int = 0
|
191 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
192 |
+
"""
|
193 |
+
q: (batch, seqlen, nheads, headdim)
|
194 |
+
k: (batch, seqlen, nheads, headdim)
|
195 |
+
seqlen_offset: can be used in generation where the qkv being passed in is only the last
|
196 |
+
token in the batch.
|
197 |
+
"""
|
198 |
+
self._update_cos_sin_cache(
|
199 |
+
q.shape[1] + seqlen_offset, device=q.device, dtype=q.dtype
|
200 |
+
)
|
201 |
+
assert self._cos_cached is not None
|
202 |
+
assert self._sin_cached is not None
|
203 |
+
if self.scale is None:
|
204 |
+
return (
|
205 |
+
apply_rotary_emb_torch(
|
206 |
+
q,
|
207 |
+
self._cos_cached[seqlen_offset:],
|
208 |
+
self._sin_cached[seqlen_offset:],
|
209 |
+
self.interleaved,
|
210 |
+
True, # inplace=True
|
211 |
+
),
|
212 |
+
apply_rotary_emb_torch(
|
213 |
+
k,
|
214 |
+
self._cos_cached[seqlen_offset:],
|
215 |
+
self._sin_cached[seqlen_offset:],
|
216 |
+
self.interleaved,
|
217 |
+
True, # inplace=True
|
218 |
+
),
|
219 |
+
) # type: ignore
|
220 |
+
else:
|
221 |
+
assert False
|
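A minimal sketch of the default (scale_base=None) rotary path above; the XPos branch ends in assert False, so only this path is exercised. Shapes are illustrative.

import torch
from esm.layers.rotary import RotaryEmbedding

batch, seqlen, nheads, headdim = 2, 16, 4, 32
rope = RotaryEmbedding(dim=headdim)   # scale_base=None -> plain rotary embeddings

q = torch.randn(batch, seqlen, nheads, headdim)
k = torch.randn(batch, seqlen, nheads, headdim)
q_rot, k_rot = rope(q, k)             # cos/sin caches are built lazily on first use
assert q_rot.shape == q.shape and k_rot.shape == k.shape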
Dyna-1/esm/layers/structure_proj.py
ADDED
@@ -0,0 +1,66 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from esm.utils.constants.physics import BB_COORDINATES
|
5 |
+
from esm.utils.structure.affine3d import (
|
6 |
+
Affine3D,
|
7 |
+
RotationMatrix,
|
8 |
+
)
|
9 |
+
|
10 |
+
|
11 |
+
class Dim6RotStructureHead(nn.Module):
|
12 |
+
# Normally, AF2 uses quaternions to specify rotations. There's some evidence that
|
13 |
+
# other representations are better behaved; the best one according to
|
14 |
+
# https://openaccess.thecvf.com/content_CVPR_2019/papers/Zhou_On_the_Continuity_of_Rotation_Representations_in_Neural_Networks_CVPR_2019_paper.pdf
|
15 |
+
# is using Gram-Schmidt on two vectors, which is implemented here.
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
input_dim: int,
|
19 |
+
trans_scale_factor: float = 10,
|
20 |
+
norm_type: str = "layernorm",
|
21 |
+
activation_fn: str = "esm_gelu",
|
22 |
+
predict_torsion_angles: bool = True,
|
23 |
+
):
|
24 |
+
super().__init__()
|
25 |
+
self.ffn1 = nn.Linear(input_dim, input_dim)
|
26 |
+
self.activation_fn = nn.GELU()
|
27 |
+
self.norm = nn.LayerNorm(input_dim)
|
28 |
+
self.proj = nn.Linear(input_dim, 9 + 7 * 2)
|
29 |
+
self.trans_scale_factor = trans_scale_factor
|
30 |
+
self.predict_torsion_angles = predict_torsion_angles
|
31 |
+
self.bb_local_coords = torch.tensor(BB_COORDINATES).float()
|
32 |
+
|
33 |
+
def forward(self, x, affine, affine_mask, **kwargs):
|
34 |
+
if affine is None:
|
35 |
+
rigids = Affine3D.identity(
|
36 |
+
x.shape[:-1],
|
37 |
+
dtype=x.dtype,
|
38 |
+
device=x.device,
|
39 |
+
requires_grad=self.training,
|
40 |
+
rotation_type=RotationMatrix,
|
41 |
+
)
|
42 |
+
else:
|
43 |
+
rigids = affine
|
44 |
+
|
45 |
+
# [*, N]
|
46 |
+
x = self.ffn1(x)
|
47 |
+
x = self.activation_fn(x)
|
48 |
+
x = self.norm(x)
|
49 |
+
trans, x, y, angles = self.proj(x).split([3, 3, 3, 7 * 2], dim=-1)
|
50 |
+
trans = trans * self.trans_scale_factor
|
51 |
+
x = x / (x.norm(dim=-1, keepdim=True) + 1e-5)
|
52 |
+
y = y / (y.norm(dim=-1, keepdim=True) + 1e-5)
|
53 |
+
update = Affine3D.from_graham_schmidt(x + trans, trans, y + trans)
|
54 |
+
rigids = rigids.compose(update.mask(affine_mask))
|
55 |
+
affine = rigids.tensor
|
56 |
+
|
57 |
+
# We approximate the positions of the backbone atoms in the global frame by applying the rigid
|
58 |
+
# transformation to the mean of the backbone atoms in the local frame.
|
59 |
+
all_bb_coords_local = (
|
60 |
+
self.bb_local_coords[None, None, :, :]
|
61 |
+
.expand(*x.shape[:-1], 3, 3)
|
62 |
+
.to(x.device)
|
63 |
+
)
|
64 |
+
pred_xyz = rigids[..., None].apply(all_bb_coords_local)
|
65 |
+
|
66 |
+
return affine, pred_xyz
|
Dyna-1/esm/layers/transformer_stack.py
ADDED
@@ -0,0 +1,93 @@
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from esm.layers.blocks import UnifiedTransformerBlock
|
7 |
+
from esm.utils.structure.affine3d import Affine3D
|
8 |
+
|
9 |
+
|
10 |
+
class TransformerStack(nn.Module):
|
11 |
+
"""
|
12 |
+
A stack of transformer blocks used in the ESM-3 model. Each block is a UnifiedTransformerBlock,
|
13 |
+
which can either be geometric attention or standard multi-head attention.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
d_model (int): The dimensionality of the input and output feature vectors.
|
17 |
+
n_heads (int): The number of attention heads.
|
18 |
+
v_heads (int | None): The number of vector heads used by the geometric attention blocks; may be None when no geometric attention layers are used.
|
19 |
+
n_layers (int): The number of transformer blocks in the stack.
|
20 |
+
n_layers_geom (int, optional): The number of transformer blocks that use geometric attention.
|
21 |
+
scale_residue (bool, optional): Whether to scale the residue connections in each transformer block.
|
22 |
+
mask_and_zero_frameless (bool, optional): Whether to mask and zero frameless positions in the input.
|
23 |
+
Only applies in the geometric attention blocks, which is conditioned on the structure
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
d_model: int,
|
29 |
+
n_heads: int,
|
30 |
+
v_heads: int | None,
|
31 |
+
n_layers: int,
|
32 |
+
n_layers_geom: int = 1,
|
33 |
+
scale_residue: bool = True,
|
34 |
+
mask_and_zero_frameless: bool = False,
|
35 |
+
bias: bool = False,
|
36 |
+
qk_layernorm: bool = True,
|
37 |
+
ffn_type: str = "swiglu", # swiglu | gelu
|
38 |
+
expansion_ratio: float = 8 / 3,
|
39 |
+
):
|
40 |
+
super().__init__()
|
41 |
+
self.blocks = nn.ModuleList(
|
42 |
+
[
|
43 |
+
UnifiedTransformerBlock(
|
44 |
+
d_model,
|
45 |
+
n_heads,
|
46 |
+
v_heads=v_heads,
|
47 |
+
use_geom_attn=i < n_layers_geom,
|
48 |
+
residue_scaling_factor=(
|
49 |
+
math.sqrt(n_layers / 36) if scale_residue else 1.0
|
50 |
+
),
|
51 |
+
expansion_ratio=expansion_ratio,
|
52 |
+
mask_and_zero_frameless=mask_and_zero_frameless,
|
53 |
+
bias=bias,
|
54 |
+
qk_layernorm=qk_layernorm,
|
55 |
+
ffn_type=ffn_type,
|
56 |
+
)
|
57 |
+
for i in range(n_layers)
|
58 |
+
]
|
59 |
+
)
|
60 |
+
self.norm = nn.LayerNorm(d_model, bias=False)
|
61 |
+
|
62 |
+
def forward(
|
63 |
+
self,
|
64 |
+
x: torch.Tensor,
|
65 |
+
sequence_id: torch.Tensor | None = None,
|
66 |
+
affine: Affine3D | None = None,
|
67 |
+
affine_mask: torch.Tensor | None = None,
|
68 |
+
chain_id: torch.Tensor | None = None,
|
69 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
70 |
+
"""
|
71 |
+
Forward pass of the TransformerStack.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
x (torch.Tensor): The input tensor of shape (batch_size, sequence_length, d_model).
|
75 |
+
sequence_id (torch.Tensor): The sequence ID tensor of shape (batch_size, sequence_length).
|
76 |
+
affine (Affine3D | None): The affine transformation tensor or None.
|
77 |
+
affine_mask (torch.Tensor | None): The affine mask tensor or None.
|
78 |
+
chain_id (torch.Tensor): The protein chain tensor of shape (batch_size, sequence_length).
|
79 |
+
Only used in geometric attention.
|
80 |
+
|
81 |
+
Returns:
|
82 |
+
post_norm: The output tensor of shape (batch_size, sequence_length, d_model).
|
83 |
+
pre_norm: The last-layer embedding (before the final LayerNorm) of shape (batch_size, sequence_length, d_model).
hiddens: The stacked per-layer hidden states of shape (n_layers, batch_size, sequence_length, d_model).
|
84 |
+
"""
|
85 |
+
*batch_dims, _ = x.shape
|
86 |
+
if chain_id is None:
|
87 |
+
chain_id = torch.ones(size=batch_dims, dtype=torch.int64, device=x.device)
|
88 |
+
hiddens = []
|
89 |
+
for block in self.blocks:
|
90 |
+
x = block(x, sequence_id, affine, affine_mask, chain_id)
|
91 |
+
hiddens.append(x)
|
92 |
+
hiddens = torch.stack(hiddens, dim=0)
|
93 |
+
return self.norm(x), x, hiddens
|
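A small instantiation sketch of the stack above in sequence-only mode (n_layers_geom=0, v_heads=None), mirroring how the ESM-C model below configures it; sizes are illustrative.

import torch
from esm.layers.transformer_stack import TransformerStack

d_model, n_heads, n_layers = 256, 8, 4
stack = TransformerStack(d_model, n_heads, v_heads=None, n_layers=n_layers, n_layers_geom=0)

x = torch.randn(2, 50, d_model)
post_norm, pre_norm, hiddens = stack(x)   # affine/chain inputs default to None / all-ones
assert hiddens.shape == (n_layers, 2, 50, d_model)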
Dyna-1/esm/models/esm3.py
ADDED
@@ -0,0 +1,606 @@
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import contextlib
|
4 |
+
from functools import partial
|
5 |
+
from typing import Callable
|
6 |
+
|
7 |
+
import attr
|
8 |
+
import einops
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from attr import dataclass
|
12 |
+
|
13 |
+
from esm.layers.regression_head import RegressionHead
|
14 |
+
from esm.layers.transformer_stack import TransformerStack
|
15 |
+
from esm.models.function_decoder import FunctionTokenDecoder
|
16 |
+
from esm.models.vqvae import (
|
17 |
+
StructureTokenDecoder,
|
18 |
+
StructureTokenEncoder,
|
19 |
+
)
|
20 |
+
from esm.sdk.api import (
|
21 |
+
ESM3InferenceClient,
|
22 |
+
ESMProtein,
|
23 |
+
ESMProteinTensor,
|
24 |
+
ForwardAndSampleOutput,
|
25 |
+
ForwardTrackData,
|
26 |
+
GenerationConfig,
|
27 |
+
LogitsConfig,
|
28 |
+
LogitsOutput,
|
29 |
+
ProteinType,
|
30 |
+
SamplingConfig,
|
31 |
+
)
|
32 |
+
from esm.tokenization import TokenizerCollectionProtocol
|
33 |
+
from esm.utils import encoding
|
34 |
+
from esm.utils.constants import esm3 as C
|
35 |
+
from esm.utils.constants.models import (
|
36 |
+
ESM3_OPEN_SMALL,
|
37 |
+
normalize_model_name,
|
38 |
+
)
|
39 |
+
from esm.utils.decoding import decode_protein_tensor
|
40 |
+
from esm.utils.generation import (
|
41 |
+
_batch_forward,
|
42 |
+
_sample_per_prompt,
|
43 |
+
_slice_tensor_dataclass,
|
44 |
+
iterative_sampling_raw,
|
45 |
+
iterative_sampling_tokens,
|
46 |
+
)
|
47 |
+
from esm.utils.misc import rbf
|
48 |
+
from esm.utils.sampling import (
|
49 |
+
_BatchedESMProteinTensor,
|
50 |
+
get_default_sampling_config,
|
51 |
+
validate_sampling_config,
|
52 |
+
)
|
53 |
+
from esm.utils.structure.affine3d import (
|
54 |
+
build_affine3d_from_coordinates,
|
55 |
+
)
|
56 |
+
|
57 |
+
|
58 |
+
@dataclass
|
59 |
+
class ESMOutput:
|
60 |
+
sequence_logits: torch.Tensor
|
61 |
+
structure_logits: torch.Tensor
|
62 |
+
secondary_structure_logits: torch.Tensor
|
63 |
+
sasa_logits: torch.Tensor
|
64 |
+
function_logits: torch.Tensor
|
65 |
+
residue_logits: torch.Tensor
|
66 |
+
embeddings: torch.Tensor
|
67 |
+
|
68 |
+
|
69 |
+
class EncodeInputs(nn.Module):
|
70 |
+
"""
|
71 |
+
Module for encoding input features in the ESM-3 model.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
d_model (int): The dimensionality of the model's hidden states.
|
75 |
+
"""
|
76 |
+
|
77 |
+
def __init__(self, d_model: int):
|
78 |
+
super().__init__()
|
79 |
+
|
80 |
+
# Sequence
|
81 |
+
self.sequence_embed = nn.Embedding(64, d_model)
|
82 |
+
# Mandatory information
|
83 |
+
self.plddt_projection = nn.Linear(16, d_model)
|
84 |
+
self.structure_per_res_plddt_projection = nn.Linear(16, d_model)
|
85 |
+
|
86 |
+
# Structure
|
87 |
+
self.structure_tokens_embed = nn.Embedding(4096 + 5, d_model)
|
88 |
+
|
89 |
+
# "Structural" features
|
90 |
+
self.ss8_embed = nn.Embedding(8 + 3, d_model)
|
91 |
+
self.sasa_embed = nn.Embedding(16 + 3, d_model)
|
92 |
+
|
93 |
+
# "Functional" features
|
94 |
+
self.function_embed = nn.ModuleList(
|
95 |
+
[nn.Embedding(260, d_model // 8, padding_idx=0) for _ in range(8)]
|
96 |
+
)
|
97 |
+
|
98 |
+
self.residue_embed = nn.EmbeddingBag(1478, d_model, mode="sum", padding_idx=0)
|
99 |
+
|
100 |
+
def forward(
|
101 |
+
self,
|
102 |
+
sequence_tokens: torch.Tensor,
|
103 |
+
structure_tokens: torch.Tensor,
|
104 |
+
average_plddt: torch.Tensor,
|
105 |
+
per_res_plddt: torch.Tensor,
|
106 |
+
ss8_tokens: torch.Tensor,
|
107 |
+
sasa_tokens: torch.Tensor,
|
108 |
+
function_tokens: torch.Tensor,
|
109 |
+
residue_annotation_tokens: torch.Tensor,
|
110 |
+
) -> torch.Tensor:
|
111 |
+
sequence_embed = self.sequence_embed(sequence_tokens)
|
112 |
+
|
113 |
+
rbf_16_fn = partial(rbf, v_min=0.0, v_max=1.0, n_bins=16)
|
114 |
+
# the `masked_fill(padding_mask.unsqueeze(2), 0)` for the two below is unnecessary
|
115 |
+
# as pad tokens never even interact with the "real" tokens (due to sequence_id)
|
116 |
+
plddt_embed = self.plddt_projection(rbf_16_fn(average_plddt))
|
117 |
+
structure_per_res_plddt = self.structure_per_res_plddt_projection(
|
118 |
+
rbf_16_fn(per_res_plddt)
|
119 |
+
)
|
120 |
+
|
121 |
+
# Structure + "structural features" embeds
|
122 |
+
structure_embed = self.structure_tokens_embed(structure_tokens)
|
123 |
+
ss8_embed = self.ss8_embed(ss8_tokens)
|
124 |
+
sasa_embed = self.sasa_embed(sasa_tokens)
|
125 |
+
|
126 |
+
# "Functional" features embeds
|
127 |
+
function_embed = torch.cat(
|
128 |
+
[
|
129 |
+
embed_fn(funcs)
|
130 |
+
for embed_fn, funcs in zip(
|
131 |
+
self.function_embed, function_tokens.unbind(-1)
|
132 |
+
)
|
133 |
+
],
|
134 |
+
-1,
|
135 |
+
)
|
136 |
+
|
137 |
+
# Residue embeds
|
138 |
+
B, L, N = residue_annotation_tokens.shape
|
139 |
+
residue_embed = self.residue_embed(
|
140 |
+
einops.rearrange(
|
141 |
+
residue_annotation_tokens, "B L N -> (B L) N", B=B, L=L, N=N
|
142 |
+
)
|
143 |
+
)
|
144 |
+
residue_embed = einops.rearrange(residue_embed, "(B L) D -> B L D", B=B, L=L)
|
145 |
+
|
146 |
+
return (
|
147 |
+
sequence_embed
|
148 |
+
+ plddt_embed
|
149 |
+
+ structure_per_res_plddt
|
150 |
+
+ structure_embed
|
151 |
+
+ ss8_embed
|
152 |
+
+ sasa_embed
|
153 |
+
+ function_embed
|
154 |
+
+ residue_embed
|
155 |
+
)
|
156 |
+
|
157 |
+
|
158 |
+
class OutputHeads(nn.Module):
|
159 |
+
def __init__(self, d_model: int):
|
160 |
+
super().__init__()
|
161 |
+
self.sequence_head = RegressionHead(d_model, 64)
|
162 |
+
self.structure_head = RegressionHead(d_model, 4096)
|
163 |
+
self.ss8_head = RegressionHead(d_model, 8 + 3)
|
164 |
+
self.sasa_head = RegressionHead(d_model, 16 + 3)
|
165 |
+
self.function_head = RegressionHead(d_model, 260 * 8)
|
166 |
+
self.residue_head = RegressionHead(d_model, 1478)
|
167 |
+
|
168 |
+
def forward(self, x: torch.Tensor, embed: torch.Tensor) -> ESMOutput:
|
169 |
+
sequence_logits = self.sequence_head(x)
|
170 |
+
structure_logits = self.structure_head(x)
|
171 |
+
secondary_structure_logits = self.ss8_head(x)
|
172 |
+
sasa_logits = self.sasa_head(x)
|
173 |
+
function_logits = self.function_head(x)
|
174 |
+
function_logits = einops.rearrange(function_logits, "... (k v) -> ... k v", k=8)
|
175 |
+
|
176 |
+
residue_logits = self.residue_head(x)
|
177 |
+
|
178 |
+
return ESMOutput(
|
179 |
+
sequence_logits=sequence_logits,
|
180 |
+
structure_logits=structure_logits,
|
181 |
+
secondary_structure_logits=secondary_structure_logits,
|
182 |
+
sasa_logits=sasa_logits,
|
183 |
+
function_logits=function_logits,
|
184 |
+
residue_logits=residue_logits,
|
185 |
+
embeddings=embed,
|
186 |
+
)
|
187 |
+
|
188 |
+
|
189 |
+
class ESM3(nn.Module, ESM3InferenceClient):
|
190 |
+
"""
|
191 |
+
ESM3 model implementation.
|
192 |
+
|
193 |
+
Args:
|
194 |
+
d_model (int): The dimensionality of the input and output feature vectors.
|
195 |
+
n_heads (int): The number of attention heads in the transformer layers.
|
196 |
+
v_heads (int): The number of vector heads used by the geometric attention layers.
|
197 |
+
n_layers (int): The number of transformer layers.
|
198 |
+
"""
|
199 |
+
|
200 |
+
def __init__(
|
201 |
+
self,
|
202 |
+
d_model: int,
|
203 |
+
n_heads: int,
|
204 |
+
v_heads: int,
|
205 |
+
n_layers: int,
|
206 |
+
structure_encoder_fn: Callable[[torch.device | str], StructureTokenEncoder],
|
207 |
+
structure_decoder_fn: Callable[[torch.device | str], StructureTokenDecoder],
|
208 |
+
function_decoder_fn: Callable[[torch.device | str], FunctionTokenDecoder],
|
209 |
+
tokenizers: TokenizerCollectionProtocol,
|
210 |
+
):
|
211 |
+
super().__init__()
|
212 |
+
self.encoder = EncodeInputs(d_model)
|
213 |
+
self.transformer = TransformerStack(
|
214 |
+
d_model, n_heads, v_heads, n_layers, mask_and_zero_frameless=True
|
215 |
+
)
|
216 |
+
self.output_heads = OutputHeads(d_model)
|
217 |
+
|
218 |
+
self.structure_encoder_fn = structure_encoder_fn
|
219 |
+
self.structure_decoder_fn = structure_decoder_fn
|
220 |
+
self.function_decoder_fn = function_decoder_fn
|
221 |
+
|
222 |
+
self._structure_encoder = None
|
223 |
+
self._structure_decoder = None
|
224 |
+
self._function_decoder = None
|
225 |
+
|
226 |
+
self.tokenizers = tokenizers
|
227 |
+
|
228 |
+
@classmethod
|
229 |
+
def from_pretrained(
|
230 |
+
cls, model_name: str = ESM3_OPEN_SMALL, device: torch.device | None = None
|
231 |
+
) -> ESM3:
|
232 |
+
from esm.pretrained import load_local_model
|
233 |
+
|
234 |
+
model_name = normalize_model_name(model_name)
|
235 |
+
if not model_name:
|
236 |
+
raise ValueError(f"Model name {model_name} is not a valid ESM3 model name.")
|
237 |
+
if device is None:
|
238 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
239 |
+
model = load_local_model(model_name, device=device)
|
240 |
+
if device.type != "cpu":
|
241 |
+
model = model.to(torch.bfloat16)
|
242 |
+
assert isinstance(model, ESM3)
|
243 |
+
return model
|
244 |
+
|
245 |
+
@property
|
246 |
+
def device(self):
|
247 |
+
return next(self.parameters()).device
|
248 |
+
|
249 |
+
@property
|
250 |
+
def raw_model(self):
|
251 |
+
return self
|
252 |
+
|
253 |
+
def get_structure_encoder(self) -> StructureTokenEncoder:
|
254 |
+
if self._structure_encoder is None:
|
255 |
+
self._structure_encoder = self.structure_encoder_fn(self.device)
|
256 |
+
return self._structure_encoder
|
257 |
+
|
258 |
+
def get_structure_decoder(self) -> StructureTokenDecoder:
|
259 |
+
if self._structure_decoder is None:
|
260 |
+
self._structure_decoder = self.structure_decoder_fn(self.device)
|
261 |
+
return self._structure_decoder
|
262 |
+
|
263 |
+
def get_function_decoder(self) -> FunctionTokenDecoder:
|
264 |
+
if self._function_decoder is None:
|
265 |
+
self._function_decoder = self.function_decoder_fn(self.device)
|
266 |
+
return self._function_decoder
|
267 |
+
|
268 |
+
def forward(
|
269 |
+
self,
|
270 |
+
*,
|
271 |
+
sequence_tokens: torch.Tensor | None = None,
|
272 |
+
structure_tokens: torch.Tensor | None = None,
|
273 |
+
ss8_tokens: torch.Tensor | None = None,
|
274 |
+
sasa_tokens: torch.Tensor | None = None,
|
275 |
+
function_tokens: torch.Tensor | None = None,
|
276 |
+
residue_annotation_tokens: torch.Tensor | None = None,
|
277 |
+
average_plddt: torch.Tensor | None = None,
|
278 |
+
per_res_plddt: torch.Tensor | None = None,
|
279 |
+
structure_coords: torch.Tensor | None = None,
|
280 |
+
chain_id: torch.Tensor | None = None,
|
281 |
+
sequence_id: torch.Tensor | None = None,
|
282 |
+
) -> tuple[ESMOutput, torch.Tensor]:  # MODIFIED FOR DYNA-1: also returns stacked hidden states
|
283 |
+
"""
|
284 |
+
Performs forward pass through the ESM3 model. Check utils to see how to tokenize inputs from raw data.
|
285 |
+
|
286 |
+
Args:
|
287 |
+
sequence_tokens (torch.Tensor, optional): The amino acid tokens.
|
288 |
+
structure_tokens (torch.Tensor, optional): The structure tokens.
|
289 |
+
ss8_tokens (torch.Tensor, optional): The secondary structure tokens.
|
290 |
+
sasa_tokens (torch.Tensor, optional): The solvent accessible surface area tokens.
|
291 |
+
function_tokens (torch.Tensor, optional): The function tokens.
|
292 |
+
residue_annotation_tokens (torch.Tensor, optional): The residue annotation tokens.
|
293 |
+
average_plddt (torch.Tensor, optional): The average plddt across the entire sequence.
|
294 |
+
per_res_plddt (torch.Tensor, optional): The per-residue plddt; if you want to specify exact plddts, use this,
|
295 |
+
otherwise, use average_plddt.
|
296 |
+
structure_coords (torch.Tensor, optional): The structure coordinates, in the form of (B, L, 3, 3).
|
297 |
+
chain_id (torch.Tensor, optional): The chain ID
|
298 |
+
sequence_id (torch.Tensor, optional): The sequence ID.
|
299 |
+
|
300 |
+
Returns:
|
301 |
+
ESMOutput: The output of the ESM3 model.
|
302 |
+
|
303 |
+
Raises:
|
304 |
+
ValueError: If all of the inputs are None (at least one must be provided).
|
305 |
+
|
306 |
+
"""
|
307 |
+
# Reasonable defaults:
|
308 |
+
try:
|
309 |
+
L, device = next(
|
310 |
+
(x.shape[1], x.device)
|
311 |
+
for x in [
|
312 |
+
sequence_tokens,
|
313 |
+
structure_tokens,
|
314 |
+
ss8_tokens,
|
315 |
+
sasa_tokens,
|
316 |
+
structure_coords,
|
317 |
+
function_tokens,
|
318 |
+
residue_annotation_tokens,
|
319 |
+
]
|
320 |
+
if x is not None
|
321 |
+
)
|
322 |
+
except StopIteration:
|
323 |
+
raise ValueError("At least one of the inputs must be non-None")
|
324 |
+
|
325 |
+
t = self.tokenizers
|
326 |
+
defaults = lambda x, tok: (
|
327 |
+
torch.full((1, L), tok, dtype=torch.long, device=device) if x is None else x
|
328 |
+
)
|
329 |
+
sequence_tokens = defaults(sequence_tokens, t.sequence.mask_token_id)
|
330 |
+
ss8_tokens = defaults(ss8_tokens, C.SS8_PAD_TOKEN)
|
331 |
+
sasa_tokens = defaults(sasa_tokens, C.SASA_PAD_TOKEN)
|
332 |
+
average_plddt = defaults(average_plddt, 1).float()
|
333 |
+
per_res_plddt = defaults(per_res_plddt, 0).float()
|
334 |
+
chain_id = defaults(chain_id, 0)
|
335 |
+
|
336 |
+
if residue_annotation_tokens is None:
|
337 |
+
residue_annotation_tokens = torch.full(
|
338 |
+
(1, L, 16), C.RESIDUE_PAD_TOKEN, dtype=torch.long, device=device
|
339 |
+
)
|
340 |
+
|
341 |
+
if function_tokens is None:
|
342 |
+
function_tokens = torch.full(
|
343 |
+
(1, L, 8), C.INTERPRO_PAD_TOKEN, dtype=torch.long, device=device
|
344 |
+
)
|
345 |
+
|
346 |
+
if structure_coords is None:
|
347 |
+
structure_coords = torch.full(
|
348 |
+
(1, L, 3, 3), float("nan"), dtype=torch.float, device=device
|
349 |
+
)
|
350 |
+
|
351 |
+
structure_coords = structure_coords[
|
352 |
+
..., :3, :
|
353 |
+
] # In case we pass in an atom14 or atom37 repr
|
354 |
+
affine, affine_mask = build_affine3d_from_coordinates(structure_coords)
|
355 |
+
|
356 |
+
structure_tokens = defaults(structure_tokens, C.STRUCTURE_MASK_TOKEN)
|
357 |
+
assert structure_tokens is not None
|
358 |
+
structure_tokens = (
|
359 |
+
structure_tokens.masked_fill(structure_tokens == -1, C.STRUCTURE_MASK_TOKEN)
|
360 |
+
.masked_fill(sequence_tokens == C.SEQUENCE_BOS_TOKEN, C.STRUCTURE_BOS_TOKEN)
|
361 |
+
.masked_fill(sequence_tokens == C.SEQUENCE_PAD_TOKEN, C.STRUCTURE_PAD_TOKEN)
|
362 |
+
.masked_fill(sequence_tokens == C.SEQUENCE_EOS_TOKEN, C.STRUCTURE_EOS_TOKEN)
|
363 |
+
.masked_fill(
|
364 |
+
sequence_tokens == C.SEQUENCE_CHAINBREAK_TOKEN,
|
365 |
+
C.STRUCTURE_CHAINBREAK_TOKEN,
|
366 |
+
)
|
367 |
+
)
|
368 |
+
|
369 |
+
x = self.encoder(
|
370 |
+
sequence_tokens,
|
371 |
+
structure_tokens,
|
372 |
+
average_plddt,
|
373 |
+
per_res_plddt,
|
374 |
+
ss8_tokens,
|
375 |
+
sasa_tokens,
|
376 |
+
function_tokens,
|
377 |
+
residue_annotation_tokens,
|
378 |
+
)
|
379 |
+
x, embedding, hidden = self.transformer(
|
380 |
+
x, sequence_id, affine, affine_mask, chain_id
|
381 |
+
)
|
382 |
+
return self.output_heads(x, embedding), hidden # MODIFIED FOR DYNA-1
|
383 |
+
|
384 |
+
# The following methods are for the ESM3InferenceClient interface
|
385 |
+
def generate(self, input: ProteinType, config: GenerationConfig) -> ProteinType:
|
386 |
+
"""Wrap around batched generation."""
|
387 |
+
proteins = self.batch_generate([input], [config])
|
388 |
+
assert len(proteins) == 1
|
389 |
+
return proteins[0]
|
390 |
+
|
391 |
+
def batch_generate(
|
392 |
+
self, inputs: list[ProteinType], configs: list[GenerationConfig]
|
393 |
+
) -> list[ProteinType]:
|
394 |
+
assert len(inputs) == len(
|
395 |
+
configs
|
396 |
+
), "Must have the same number of prompts and configs."
|
397 |
+
|
398 |
+
if inputs == []:
|
399 |
+
# Nothing to do.
|
400 |
+
return []
|
401 |
+
|
402 |
+
# Make sure prompts are of the same type.
|
403 |
+
t = type(inputs[0])
|
404 |
+
for i in range(1, len(inputs)):
|
405 |
+
assert isinstance(inputs[i], t), (
|
406 |
+
"Prompts must have the same type. Got "
|
407 |
+
f"{t.__name__ and type(inputs[i]).__name__} instead."
|
408 |
+
)
|
409 |
+
|
410 |
+
if isinstance(inputs[0], ESMProtein):
|
411 |
+
return iterative_sampling_raw(self, inputs, configs) # type: ignore
|
412 |
+
elif isinstance(inputs[0], ESMProteinTensor):
|
413 |
+
return iterative_sampling_tokens(
|
414 |
+
self,
|
415 |
+
inputs, # type: ignore
|
416 |
+
configs,
|
417 |
+
self.tokenizers, # type: ignore
|
418 |
+
)
|
419 |
+
else:
|
420 |
+
raise ValueError("Input must be an ESMProtein or ESMProteinTensor")
|
421 |
+
|
422 |
+
def encode(self, input: ESMProtein) -> ESMProteinTensor:
|
423 |
+
input = attr.evolve(input) # Make a copy
|
424 |
+
|
425 |
+
sequence_tokens = None
|
426 |
+
structure_tokens = None
|
427 |
+
secondary_structure_tokens = None
|
428 |
+
sasa_tokens = None
|
429 |
+
function_tokens = None
|
430 |
+
residue_annotation_tokens = None
|
431 |
+
|
432 |
+
coordinates = None
|
433 |
+
|
434 |
+
if input.sequence is not None:
|
435 |
+
sequence_tokens = encoding.tokenize_sequence(
|
436 |
+
input.sequence, self.tokenizers.sequence, add_special_tokens=True
|
437 |
+
)
|
438 |
+
if input.secondary_structure is not None:
|
439 |
+
secondary_structure_tokens = encoding.tokenize_secondary_structure(
|
440 |
+
input.secondary_structure,
|
441 |
+
self.tokenizers.secondary_structure,
|
442 |
+
add_special_tokens=True,
|
443 |
+
)
|
444 |
+
if input.sasa is not None:
|
445 |
+
sasa_tokens = encoding.tokenize_sasa(
|
446 |
+
input.sasa, self.tokenizers.sasa, add_special_tokens=True
|
447 |
+
)
|
448 |
+
|
449 |
+
# Infer input length
|
450 |
+
sequence_length = -1
|
451 |
+
if sequence_tokens is not None:
|
452 |
+
sequence_length = len(sequence_tokens)
|
453 |
+
elif secondary_structure_tokens is not None:
|
454 |
+
sequence_length = len(secondary_structure_tokens)
|
455 |
+
elif sasa_tokens is not None:
|
456 |
+
sequence_length = len(sasa_tokens)
|
457 |
+
|
458 |
+
# Try to infer input length from structure data
|
459 |
+
if input.coordinates is not None:
|
460 |
+
coordinates, _, structure_tokens = encoding.tokenize_structure(
|
461 |
+
input.coordinates,
|
462 |
+
self.get_structure_encoder(),
|
463 |
+
structure_tokenizer=self.tokenizers.structure,
|
464 |
+
reference_sequence=input.sequence or "",
|
465 |
+
add_special_tokens=True,
|
466 |
+
)
|
467 |
+
if sequence_length == -1:
|
468 |
+
sequence_length = len(structure_tokens)
|
469 |
+
|
470 |
+
if sequence_length == -1:
|
471 |
+
raise ValueError(
|
472 |
+
"Cannot infer input length from input data. Please provide one of: sequence, structure, secondary_structure, sasa.\n"
|
473 |
+
"To condition on sequence length only, use ESM3LocalInferenceClient.get_default_sequence(sequence_length) to generate a default sequence input."
|
474 |
+
)
|
475 |
+
|
476 |
+
# Function and Residue annotations
|
477 |
+
if input.function_annotations is not None:
|
478 |
+
if input.sequence is None:
|
479 |
+
reference_sequence = encoding.get_default_sequence(sequence_length - 2)
|
480 |
+
else:
|
481 |
+
reference_sequence = input.sequence
|
482 |
+
(function_tokens, residue_annotation_tokens) = (
|
483 |
+
encoding.tokenize_function_annotations(
|
484 |
+
input.function_annotations,
|
485 |
+
reference_sequence=reference_sequence,
|
486 |
+
function_tokenizer=self.tokenizers.function,
|
487 |
+
residue_annotation_tokenizer=self.tokenizers.residue_annotations,
|
488 |
+
add_special_tokens=True,
|
489 |
+
)
|
490 |
+
)
|
491 |
+
|
492 |
+
return ESMProteinTensor(
|
493 |
+
sequence=sequence_tokens,
|
494 |
+
structure=structure_tokens,
|
495 |
+
secondary_structure=secondary_structure_tokens,
|
496 |
+
sasa=sasa_tokens,
|
497 |
+
function=function_tokens,
|
498 |
+
residue_annotations=residue_annotation_tokens,
|
499 |
+
coordinates=coordinates,
|
500 |
+
).to(next(self.parameters()).device)
|
501 |
+
|
502 |
+
def decode(self, input: ESMProteinTensor) -> ESMProtein:
|
503 |
+
return decode_protein_tensor(
|
504 |
+
input=input,
|
505 |
+
tokenizers=self.tokenizers,
|
506 |
+
structure_token_decoder=self.get_structure_decoder(),
|
507 |
+
function_token_decoder=self.get_function_decoder(),
|
508 |
+
)
|
509 |
+
|
510 |
+
def logits(
|
511 |
+
self,
|
512 |
+
input: ESMProteinTensor | _BatchedESMProteinTensor,
|
513 |
+
config: LogitsConfig = LogitsConfig(),
|
514 |
+
) -> LogitsOutput:
|
515 |
+
if not isinstance(input, _BatchedESMProteinTensor):
|
516 |
+
# Create batch dimension if necessary.
|
517 |
+
input = _BatchedESMProteinTensor.from_protein_tensor(input)
|
518 |
+
|
519 |
+
device = torch.device(input.device)
|
520 |
+
|
521 |
+
# Default plddt conditioning for inference. 1s where coordinates are provided.
|
522 |
+
if input.coordinates is None:
|
523 |
+
per_res_plddt = None
|
524 |
+
else:
|
525 |
+
# 1.0 if all coordinates at specific indices have valid non-nan values.
|
526 |
+
per_res_plddt = input.coordinates.isfinite().all(dim=-1).any(dim=-1).float()
|
527 |
+
|
528 |
+
with (
|
529 |
+
torch.no_grad(), # Assume no gradients for now...
|
530 |
+
torch.autocast(enabled=True, device_type=device.type, dtype=torch.bfloat16) # type: ignore
|
531 |
+
if device.type == "cuda"
|
532 |
+
else contextlib.nullcontext(),
|
533 |
+
):
|
534 |
+
output = self.forward(
|
535 |
+
sequence_tokens=input.sequence,
|
536 |
+
structure_tokens=input.structure,
|
537 |
+
ss8_tokens=input.secondary_structure,
|
538 |
+
sasa_tokens=input.sasa,
|
539 |
+
function_tokens=input.function,
|
540 |
+
residue_annotation_tokens=input.residue_annotations,
|
541 |
+
average_plddt=torch.tensor(1.0, device=input.device),
|
542 |
+
per_res_plddt=per_res_plddt,
|
543 |
+
structure_coords=input.coordinates,
|
544 |
+
chain_id=None,
|
545 |
+
sequence_id=None,
|
546 |
+
)
|
547 |
+
|
548 |
+
output = ESMOutput(
|
549 |
+
**{k: v.to(device).to(torch.float32) for k, v in vars(output).items()}
|
550 |
+
)
|
551 |
+
|
552 |
+
return LogitsOutput(
|
553 |
+
logits=ForwardTrackData(
|
554 |
+
sequence=output.sequence_logits if config.sequence else None,
|
555 |
+
structure=output.structure_logits if config.structure else None,
|
556 |
+
secondary_structure=output.secondary_structure_logits
|
557 |
+
if config.secondary_structure
|
558 |
+
else None,
|
559 |
+
sasa=output.sasa_logits if config.sasa else None,
|
560 |
+
function=output.function_logits if config.function else None,
|
561 |
+
),
|
562 |
+
residue_annotation_logits=output.residue_logits
|
563 |
+
if config.residue_annotations
|
564 |
+
else None,
|
565 |
+
embeddings=output.embeddings if config.return_embeddings else None,
|
566 |
+
)
|
567 |
+
|
568 |
+
def forward_and_sample(
|
569 |
+
self, input: ESMProteinTensor, sampling_configuration: SamplingConfig
|
570 |
+
) -> ForwardAndSampleOutput:
|
571 |
+
validate_sampling_config(sampling_configuration, on_invalid="warn")
|
572 |
+
|
573 |
+
protein_tensor = attr.evolve(input) # Make a copy
|
574 |
+
|
575 |
+
device = next(self.parameters()).device
|
576 |
+
|
577 |
+
sampling_config = sampling_configuration
|
578 |
+
if sampling_config is None:
|
579 |
+
sampling_config = get_default_sampling_config(self.tokenizers)
|
580 |
+
|
581 |
+
# Initialize default values for missing tracks
|
582 |
+
default_protein_tensor = ESMProteinTensor.empty(
|
583 |
+
len(input) - 2, tokenizers=self.tokenizers, device=input.device
|
584 |
+
)
|
585 |
+
for track in attr.fields(ESMProteinTensor):
|
586 |
+
if getattr(protein_tensor, track.name, None) is None:
|
587 |
+
setattr(
|
588 |
+
protein_tensor,
|
589 |
+
track.name,
|
590 |
+
getattr(default_protein_tensor, track.name, None),
|
591 |
+
)
|
592 |
+
|
593 |
+
if len(protein_tensor) <= 0:
|
594 |
+
raise ValueError("No input data provided")
|
595 |
+
|
596 |
+
# Move input protein to proper device.
|
597 |
+
batched_protein = _BatchedESMProteinTensor.from_protein_tensor(protein_tensor)
|
598 |
+
batched_protein.to(device)
|
599 |
+
|
600 |
+
logits_output: LogitsOutput = _batch_forward(self, batched_protein)
|
601 |
+
forward_and_sample_out: ForwardAndSampleOutput = _sample_per_prompt(
|
602 |
+
batched_protein, logits_output, sampling_config, self.tokenizers
|
603 |
+
)
|
604 |
+
|
605 |
+
# There is only 1 prompt to sample for.
|
606 |
+
return _slice_tensor_dataclass(forward_and_sample_out, 0)
|
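Since forward() above was modified for Dyna-1 to also return the stacked per-layer hidden states, callers unpack two values. A hedged calling sketch (the checkpoint download, example sequence, and device handling are assumptions, not part of this file):

import torch
from esm.models.esm3 import ESM3
from esm.sdk.api import ESMProtein

model = ESM3.from_pretrained()                            # ESM3_OPEN_SMALL by default
tokens = model.encode(ESMProtein(sequence="MKTAYIAKQR"))  # example sequence (assumption)

with torch.no_grad():
    output, hidden = model(sequence_tokens=tokens.sequence.unsqueeze(0))
# `output` holds the per-track logits (ESMOutput); `hidden` stacks every layer's activations,
# which is the extra return value added for Dyna-1.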
Dyna-1/esm/models/esmc.py
ADDED
@@ -0,0 +1,164 @@
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import contextlib
|
4 |
+
|
5 |
+
import attr
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from attr import dataclass
|
9 |
+
|
10 |
+
from esm.layers.regression_head import RegressionHead
|
11 |
+
from esm.layers.transformer_stack import TransformerStack
|
12 |
+
from esm.sdk.api import (
|
13 |
+
ESMCInferenceClient,
|
14 |
+
ESMProtein,
|
15 |
+
ESMProteinTensor,
|
16 |
+
ForwardTrackData,
|
17 |
+
LogitsConfig,
|
18 |
+
LogitsOutput,
|
19 |
+
)
|
20 |
+
from esm.tokenization import EsmSequenceTokenizer
|
21 |
+
from esm.utils import encoding
|
22 |
+
from esm.utils.constants.models import ESMC_600M
|
23 |
+
from esm.utils.decoding import decode_sequence
|
24 |
+
from esm.utils.misc import stack_variable_length_tensors
|
25 |
+
from esm.utils.sampling import _BatchedESMProteinTensor
|
26 |
+
|
27 |
+
|
28 |
+
@dataclass
|
29 |
+
class ESMCOutput:
|
30 |
+
sequence_logits: torch.Tensor
|
31 |
+
embeddings: torch.Tensor | None
|
32 |
+
hidden_states: torch.Tensor | None
|
33 |
+
|
34 |
+
|
35 |
+
class ESMC(nn.Module, ESMCInferenceClient):
|
36 |
+
"""
|
37 |
+
ESMC model implementation.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
d_model (int): The dimensionality of the input and output feature vectors.
|
41 |
+
n_heads (int): The number of attention heads in the transformer layers.
|
42 |
+
n_layers (int): The number of transformer layers.
|
43 |
+
"""
|
44 |
+
|
45 |
+
def __init__(
|
46 |
+
self, d_model: int, n_heads: int, n_layers: int, tokenizer: EsmSequenceTokenizer
|
47 |
+
):
|
48 |
+
super().__init__()
|
49 |
+
self.embed = nn.Embedding(64, d_model)
|
50 |
+
self.transformer = TransformerStack(
|
51 |
+
d_model, n_heads, None, n_layers, n_layers_geom=0
|
52 |
+
)
|
53 |
+
self.sequence_head = RegressionHead(d_model, 64)
|
54 |
+
self.tokenizer = tokenizer
|
55 |
+
|
56 |
+
@classmethod
|
57 |
+
def from_pretrained(
|
58 |
+
cls, model_name: str = ESMC_600M, device: torch.device | None = None
|
59 |
+
) -> ESMC:
|
60 |
+
from esm.pretrained import load_local_model
|
61 |
+
|
62 |
+
if device is None:
|
63 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
64 |
+
model = load_local_model(model_name, device=device)
|
65 |
+
if device.type != "cpu":
|
66 |
+
model = model.to(torch.bfloat16)
|
67 |
+
assert isinstance(model, ESMC)
|
68 |
+
return model
|
69 |
+
|
70 |
+
@property
|
71 |
+
def device(self):
|
72 |
+
return next(self.parameters()).device
|
73 |
+
|
74 |
+
@property
|
75 |
+
def raw_model(self):
|
76 |
+
return self
|
77 |
+
|
78 |
+
def _tokenize(self, sequence: list[str]) -> torch.Tensor:
|
79 |
+
pad = self.tokenizer.pad_token_id
|
80 |
+
assert pad is not None
|
81 |
+
return stack_variable_length_tensors(
|
82 |
+
[
|
83 |
+
encoding.tokenize_sequence(x, self.tokenizer, add_special_tokens=True)
|
84 |
+
for x in sequence
|
85 |
+
],
|
86 |
+
constant_value=pad,
|
87 |
+
).to(next(self.parameters()).device)
|
88 |
+
|
89 |
+
def _detokenize(self, sequence: torch.Tensor) -> list[str]:
|
90 |
+
pad = self.tokenizer.pad_token_id
|
91 |
+
assert pad is not None
|
92 |
+
assert sequence.ndim == 2
|
93 |
+
return [decode_sequence(x[x != pad][1:-1], self.tokenizer) for x in sequence]
|
94 |
+
|
95 |
+
def forward(
|
96 |
+
self,
|
97 |
+
sequence_tokens: torch.Tensor | None = None,
|
98 |
+
sequence_id: torch.Tensor | None = None,
|
99 |
+
) -> tuple[ESMCOutput, torch.Tensor]:  # MODIFIED FOR DYNA-1: also returns stacked hidden states
|
100 |
+
"""
|
101 |
+
Performs forward pass through the ESMC model. Check utils to see how to tokenize inputs from raw data.
|
102 |
+
|
103 |
+
Args:
|
104 |
+
sequence_tokens (torch.Tensor, optional): The amino acid tokens.
|
105 |
+
sequence_id (torch.Tensor, optional): The sequence ID.
|
106 |
+
|
107 |
+
Returns:
|
108 |
+
ESMCOutput: The output of the ESMC model.
|
109 |
+
|
110 |
+
"""
|
111 |
+
if sequence_id is None:
|
112 |
+
sequence_id = sequence_tokens == self.tokenizer.pad_token_id
|
113 |
+
|
114 |
+
x = self.embed(sequence_tokens)
|
115 |
+
x, _, hidden = self.transformer(x, sequence_id=sequence_id)
|
116 |
+
sequence_logits = self.sequence_head(x)
|
117 |
+
output = ESMCOutput(
|
118 |
+
sequence_logits=sequence_logits, embeddings=x, hidden_states=hidden
|
119 |
+
)
|
120 |
+
return output, hidden  # MODIFIED FOR DYNA-1: also return the per-layer hidden states
|
121 |
+
|
122 |
+
def encode(self, input: ESMProtein) -> ESMProteinTensor:
|
123 |
+
input = attr.evolve(input) # Make a copy
|
124 |
+
sequence_tokens = None
|
125 |
+
|
126 |
+
if input.sequence is not None:
|
127 |
+
sequence_tokens = self._tokenize([input.sequence])[0]
|
128 |
+
return ESMProteinTensor(sequence=sequence_tokens).to(
|
129 |
+
next(self.parameters()).device
|
130 |
+
)
|
131 |
+
|
132 |
+
def decode(self, input: ESMProteinTensor) -> ESMProtein:
|
133 |
+
input = attr.evolve(input) # Make a copy
|
134 |
+
|
135 |
+
assert input.sequence is not None
|
136 |
+
sequence = self._detokenize(input.sequence)[0]
|
137 |
+
|
138 |
+
return ESMProtein(sequence=sequence)
|
139 |
+
|
140 |
+
def logits(
|
141 |
+
self,
|
142 |
+
input: ESMProteinTensor | _BatchedESMProteinTensor,
|
143 |
+
config: LogitsConfig = LogitsConfig(),
|
144 |
+
) -> LogitsOutput:
|
145 |
+
if not isinstance(input, _BatchedESMProteinTensor):
|
146 |
+
# Create batch dimension if necessary.
|
147 |
+
input = _BatchedESMProteinTensor.from_protein_tensor(input)
|
148 |
+
|
149 |
+
device = torch.device(input.device)
|
150 |
+
|
151 |
+
with (
|
152 |
+
torch.no_grad(),
|
153 |
+
torch.autocast(enabled=True, device_type=device.type, dtype=torch.bfloat16) # type: ignore
|
154 |
+
if device.type == "cuda"
|
155 |
+
else contextlib.nullcontext(),
|
156 |
+
):
|
157 |
+
output, _ = self.forward(sequence_tokens=input.sequence)  # forward() returns (output, hidden) after the Dyna-1 modification
|
158 |
+
|
159 |
+
return LogitsOutput(
|
160 |
+
logits=ForwardTrackData(
|
161 |
+
sequence=output.sequence_logits if config.sequence else None
|
162 |
+
),
|
163 |
+
embeddings=output.embeddings if config.return_embeddings else None,
|
164 |
+
)
|
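A minimal usage sketch of the ESMC client defined above, assuming the Dyna-1-modified forward() that returns (output, hidden). Only methods shown in this file (from_pretrained, encode, logits) are used; the example sequence and variable names are illustrative, and running this requires the ESMC weights to be available locally.

import torch

from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig
from esm.utils.constants.models import ESMC_600M

# Load pretrained ESMC (bfloat16 on GPU, float32 on CPU, per from_pretrained above).
model = ESMC.from_pretrained(ESMC_600M)

# Tokenize a raw sequence into an ESMProteinTensor.
protein = ESMProtein(sequence="MVLSPADKTNVKAAW")  # illustrative sequence
protein_tensor = model.encode(protein)

# One forward pass returning sequence logits and final-layer embeddings.
logits_output = model.logits(
    protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
)

# The Dyna-1-modified forward() also exposes the per-layer hidden states directly.
with torch.no_grad():
    output, hidden = model(sequence_tokens=protein_tensor.sequence[None])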
Dyna-1/esm/models/function_decoder.py
ADDED
@@ -0,0 +1,306 @@
1 |
+
"""Function Token Decoder."""
|
2 |
+
|
3 |
+
from collections import defaultdict
|
4 |
+
from dataclasses import dataclass, field
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import pandas as pd
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from cloudpathlib import AnyPath
|
12 |
+
|
13 |
+
from esm.layers.regression_head import RegressionHead
|
14 |
+
from esm.layers.transformer_stack import TransformerStack
|
15 |
+
from esm.tokenization.function_tokenizer import (
|
16 |
+
InterProQuantizedTokenizer,
|
17 |
+
)
|
18 |
+
from esm.utils.constants import esm3 as C
|
19 |
+
from esm.utils.misc import merge_annotations, merge_ranges
|
20 |
+
from esm.utils.types import FunctionAnnotation
|
21 |
+
|
22 |
+
|
23 |
+
@dataclass(frozen=True)
|
24 |
+
class FunctionTokenDecoderConfig:
|
25 |
+
"""Configures function token decoder."""
|
26 |
+
|
27 |
+
# Embedding dimension of decoder.
|
28 |
+
d_model: int = 1024
|
29 |
+
# Number of attention heads of decoder.
|
30 |
+
n_heads: int = 8
|
31 |
+
# Number of layers of decoder.
|
32 |
+
n_layers: int = 3
|
33 |
+
# Number of integer values that function tokens may assume.
|
34 |
+
function_token_vocab_size: int = 260
|
35 |
+
# Number of function tokens at each position.
|
36 |
+
function_token_depth: int = 8
|
37 |
+
# Number of InterPro labels that can be decoded.
|
38 |
+
num_interpro_classes: int = 29026
|
39 |
+
# Number of function keywords that can be decoded.
|
40 |
+
keyword_vocabulary_size: int = 58641
|
41 |
+
# List of supported InterPro ids.
|
42 |
+
interpro_entry_list: str = field(default_factory=lambda: str(C.INTERPRO_ENTRY))
|
43 |
+
# Path to keywords vocabulary.
|
44 |
+
keyword_vocabulary_path: str = field(
|
45 |
+
default_factory=lambda: str(C.data_root("esm3") / C.KEYWORDS_VOCABULARY)
|
46 |
+
)
|
47 |
+
# Whether to unpack LSH bits into single-bit tokens.
|
48 |
+
unpack_lsh_bits: bool = True
|
49 |
+
# The number of special tokens in the function tokenizer vocabulary which come
|
50 |
+
# before the LSH tokens.
|
51 |
+
num_special_tokens: int = 4
|
52 |
+
# The number of bits per LSH token in the function tokenizer.
|
53 |
+
bits_per_token: int = 8
|
54 |
+
|
55 |
+
|
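# Sketch of how the config above sizes the decoder's embedding table (see
# FunctionTokenDecoder.__init__ below). The values are the defaults listed above;
# the arithmetic is illustrative only.
function_token_depth = 8          # default above
bits_per_token = 8                # default above
function_token_vocab_size = 260   # default above

# unpack_lsh_bits=True: each depth position is split into its bits, and every
# bit gets its own 2-entry slot in the embedding vocabulary.
unpacked_vocab = 2 * function_token_depth * bits_per_token        # 128

# unpack_lsh_bits=False: each depth position gets a full copy of the token vocab.
packed_vocab = function_token_depth * function_token_vocab_size   # 2080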
56 |
+
class FunctionTokenDecoder(nn.Module):
|
57 |
+
def __init__(self, config: FunctionTokenDecoderConfig | None = None):
|
58 |
+
"""Constructs function token decoder."""
|
59 |
+
super().__init__()
|
60 |
+
if config is None:
|
61 |
+
config = FunctionTokenDecoderConfig()
|
62 |
+
self.config = config
|
63 |
+
|
64 |
+
# Get the supported set of InterPro ids.
|
65 |
+
with AnyPath(config.interpro_entry_list).open("r") as f:
|
66 |
+
df = pd.read_csv(f, sep="\t")
|
67 |
+
self.interpro_ids = sorted(df.ENTRY_AC)
|
68 |
+
self.interpro2index = {
|
69 |
+
interpro_id: i for i, interpro_id in enumerate(self.interpro_ids)
|
70 |
+
}
|
71 |
+
assert len(self.interpro_ids) == config.num_interpro_classes
|
72 |
+
|
73 |
+
with AnyPath(config.keyword_vocabulary_path).open("r") as f:
|
74 |
+
self.keywords_vocabulary: list[str] = list(f.read().strip().split("\n"))
|
75 |
+
assert len(self.keywords_vocabulary) == config.keyword_vocabulary_size
|
76 |
+
|
77 |
+
if config.unpack_lsh_bits:
|
78 |
+
vocab_size = 2 * config.function_token_depth * config.bits_per_token
|
79 |
+
else:
|
80 |
+
# Function-token id's re-use the same token ids at each position along the depth
|
81 |
+
# dimension, despite distinct meanings. The decoder should take this into
|
82 |
+
# account so create distinct embeddings for tokens at each position.
|
83 |
+
vocab_size = (
|
84 |
+
self.config.function_token_depth * self.config.function_token_vocab_size
|
85 |
+
)
|
86 |
+
|
87 |
+
self.embedding = nn.Embedding(
|
88 |
+
# Function-token id's re-use the same token ids at each position along the
|
89 |
+
# depth dimension, despite distinct meanings. The decoder should take this
|
90 |
+
# into account so create distinct embeddings for tokens at each position.
|
91 |
+
num_embeddings=(vocab_size),
|
92 |
+
embedding_dim=config.d_model,
|
93 |
+
)
|
94 |
+
self.decoder = TransformerStack(
|
95 |
+
d_model=config.d_model,
|
96 |
+
n_heads=config.n_heads,
|
97 |
+
v_heads=None,
|
98 |
+
n_layers=config.n_layers,
|
99 |
+
n_layers_geom=0,
|
100 |
+
scale_residue=False,
|
101 |
+
bias=True,
|
102 |
+
qk_layernorm=False,
|
103 |
+
ffn_type="gelu",
|
104 |
+
expansion_ratio=4,
|
105 |
+
)
|
106 |
+
self.heads = nn.ModuleDict(
|
107 |
+
{
|
108 |
+
# Binary classification head predicting which keywords are present.
|
109 |
+
"keyword_logits": RegressionHead(
|
110 |
+
d_model=config.d_model,
|
111 |
+
output_dim=config.keyword_vocabulary_size,
|
112 |
+
hidden_dim=4 * config.d_model,
|
113 |
+
),
|
114 |
+
# Regresses the TF-IDF value of each present keyword.
|
115 |
+
"keyword_tfidf": RegressionHead(
|
116 |
+
d_model=config.d_model,
|
117 |
+
output_dim=config.keyword_vocabulary_size,
|
118 |
+
hidden_dim=4 * config.d_model,
|
119 |
+
),
|
120 |
+
# Predicts which InterPro annotations are present.
|
121 |
+
"interpro_logits": RegressionHead(
|
122 |
+
d_model=config.d_model,
|
123 |
+
output_dim=config.num_interpro_classes,
|
124 |
+
hidden_dim=4 * config.d_model,
|
125 |
+
),
|
126 |
+
}
|
127 |
+
)
|
128 |
+
|
129 |
+
def forward(self, token_ids: torch.Tensor) -> dict[str, torch.Tensor]:
|
130 |
+
"""Forward pass through function token decoder.
|
131 |
+
|
132 |
+
Args:
|
133 |
+
token_ids: <int>[batch_size, function_token_depth] batch of function token
|
134 |
+
ids to decode.
|
135 |
+
Returns:
|
136 |
+
interpro_logits: binary classification logits tensor of shape
|
137 |
+
<float>[batch_size, num_interpro_classes]
|
138 |
+
"""
|
139 |
+
assert token_ids.ndim == 2
|
140 |
+
assert token_ids.shape[1] == self.config.function_token_depth
|
141 |
+
batch_size, depth = token_ids.shape
|
142 |
+
|
143 |
+
if self.config.unpack_lsh_bits:
|
144 |
+
# Shift values into [0, 2^bits/token)
|
145 |
+
lsh_bits = token_ids - self.config.num_special_tokens
|
146 |
+
# extract each bit. (hob stands for highest-order bit)
|
147 |
+
bits = torch.concat(
|
148 |
+
[
|
149 |
+
torch.bitwise_and(lsh_bits, 1 << hob).gt(0).to(torch.int32)
|
150 |
+
for hob in range(self.config.bits_per_token)
|
151 |
+
],
|
152 |
+
dim=1,
|
153 |
+
)
|
154 |
+
assert bits.shape == (batch_size, depth * self.config.bits_per_token)
|
155 |
+
|
156 |
+
# Shift each bit into individual vocabulary ranges, so they get distinct
|
157 |
+
# embeddings.
|
158 |
+
vocab_offsets = 2 * torch.arange(
|
159 |
+
depth * self.config.bits_per_token, device=token_ids.device
|
160 |
+
)
|
161 |
+
inputs = vocab_offsets[None, :] + bits
|
162 |
+
|
163 |
+
# zero-out special tokens, i.e. non LSH tokens.
|
164 |
+
where_special = token_ids < self.config.num_special_tokens
|
165 |
+
inputs = torch.where(where_special.any(dim=1, keepdim=True), 0, inputs)
|
166 |
+
else:
|
167 |
+
# Apply depth-position offset to use distinct vocabs. See __init__ for
|
168 |
+
# explanation.
|
169 |
+
vocab_offsets = self.config.function_token_vocab_size * torch.arange(
|
170 |
+
self.config.function_token_depth, device=token_ids.device
|
171 |
+
)
|
172 |
+
inputs = token_ids + vocab_offsets[None, :]
|
173 |
+
|
174 |
+
embed = self.embedding(inputs)
|
175 |
+
encoding, _, _ = self.decoder(embed)
|
176 |
+
pooled = torch.mean(encoding, dim=1)
|
177 |
+
|
178 |
+
return {name: head(pooled) for name, head in self.heads.items()}
|
179 |
+
|
180 |
+
@property
|
181 |
+
def device(self) -> torch.device:
|
182 |
+
return next(self.parameters()).device
|
183 |
+
|
184 |
+
def decode(
|
185 |
+
self,
|
186 |
+
function_token_ids: torch.Tensor,
|
187 |
+
tokenizer: InterProQuantizedTokenizer,
|
188 |
+
decode_annotations: bool = True,
|
189 |
+
annotation_threshold: float = 0.1,
|
190 |
+
decode_keywords=True,
|
191 |
+
keywords_threshold: float = 0.5,
|
192 |
+
annotation_min_length: int | None = 5,
|
193 |
+
annotation_gap_merge_max: int | None = 3,
|
194 |
+
):
|
195 |
+
"""Decodes function tokens into predicted annotations and keywords.
|
196 |
+
|
197 |
+
Args:
|
198 |
+
function_token_ids: <int>[length, depth] function token ids. NOTE:
|
199 |
+
without <bos>/<eos> prefix
|
200 |
+
tokenizer: function tokenizer.
|
201 |
+
decode_annotations: whether to decode InterPro annotations.
|
202 |
+
annotation_threshold: threshold for emitting a function annotation.
|
203 |
+
decode_keywords: whether to decode function keywords.
|
204 |
+
keywords_threshold: threshold for emitting a keyword.
|
205 |
+
annotation_min_length: optional minimum length of predicted annotations for
|
206 |
+
size filtering.
|
207 |
+
annotation_gap_merge_max: optionally merge annotations of the same type separated by gaps of at most this many positions.
|
208 |
+
Returns:
|
209 |
+
Decoder outputs:
|
210 |
+
- "interpro_logits": <float>[length, num_interpro] predicted interpro logits.
|
211 |
+
- "interpro_preds": <bool>[length, num_interpro] predicted intepro labels.
|
212 |
+
- "interpro_annotations": list[FunctionAnnotation] predicted InterPro
|
213 |
+
annotations
|
214 |
+
- "keyword_logits": <float>[length, keyword_vocabulary] binary prediciton
|
215 |
+
logits for keywords.
|
216 |
+
- "function_keywords": list[FunctionAnnotation] predicted function keyword
|
217 |
+
ranges.
|
218 |
+
"""
|
219 |
+
assert function_token_ids.ndim == 2
|
220 |
+
assert function_token_ids.shape[1] == tokenizer.depth
|
221 |
+
assert self.config.function_token_depth == tokenizer.depth
|
222 |
+
|
223 |
+
outputs = {}
|
224 |
+
|
225 |
+
outputs = self(function_token_ids.to(self.device))
|
226 |
+
|
227 |
+
# Only decode in positions that have function tokens.
|
228 |
+
where_decode = torch.all(
|
229 |
+
(function_token_ids != tokenizer.vocab_to_index["<pad>"])
|
230 |
+
& (function_token_ids != tokenizer.vocab_to_index["<none>"])
|
231 |
+
& (function_token_ids != tokenizer.vocab_to_index["<unk>"]),
|
232 |
+
dim=1,
|
233 |
+
)
|
234 |
+
|
235 |
+
# Decode InterPro annotations ranges.
|
236 |
+
interpro_preds = F.sigmoid(outputs["interpro_logits"])
|
237 |
+
interpro_preds = interpro_preds >= annotation_threshold
|
238 |
+
interpro_preds[~where_decode, :] = False
|
239 |
+
outputs["interpro_preds"] = interpro_preds
|
240 |
+
if decode_annotations:
|
241 |
+
annotations: list[FunctionAnnotation] = []
|
242 |
+
preds: np.ndarray = interpro_preds.detach().cpu().numpy()
|
243 |
+
for position_index, class_index in zip(*preds.nonzero()):
|
244 |
+
interpro_id = self.interpro_ids[class_index]
|
245 |
+
annotation = FunctionAnnotation(
|
246 |
+
label=interpro_id,
|
247 |
+
start=position_index, # one-index inclusive (BOS shifts indexes +1)
|
248 |
+
end=position_index, # one-index inclusive
|
249 |
+
)
|
250 |
+
annotations.append(annotation)
|
251 |
+
|
252 |
+
annotations = merge_annotations(
|
253 |
+
annotations, merge_gap_max=annotation_gap_merge_max
|
254 |
+
)
|
255 |
+
|
256 |
+
# Drop very small annotations.
|
257 |
+
if annotation_min_length is not None:
|
258 |
+
annotations = [
|
259 |
+
annotation
|
260 |
+
for annotation in annotations
|
261 |
+
if annotation.end - annotation.start + 1 >= annotation_min_length
|
262 |
+
]
|
263 |
+
|
264 |
+
outputs["interpro_annotations"] = annotations
|
265 |
+
|
266 |
+
# Decode function keyword ranges.
|
267 |
+
keyword_logits = outputs["keyword_logits"]
|
268 |
+
keyword_logits[~where_decode, :] = -torch.inf
|
269 |
+
if decode_keywords:
|
270 |
+
keyword_preds = F.sigmoid(keyword_logits) >= keywords_threshold
|
271 |
+
outputs["function_keywords"] = self._preds_to_keywords(
|
272 |
+
keyword_preds.detach().cpu().numpy()
|
273 |
+
)
|
274 |
+
|
275 |
+
return outputs
|
276 |
+
|
277 |
+
def _preds_to_keywords(self, keyword_preds: np.ndarray) -> list[FunctionAnnotation]:
|
278 |
+
"""Converts output log-TFDF to predicted keywords over the sequence.
|
279 |
+
|
280 |
+
Args:
|
281 |
+
keyword_preds: <bool>[length, keyword_vocab] positional predictions of
|
282 |
+
function keywords from the keyword prediction head.
|
283 |
+
Returns:
|
284 |
+
Non-overlapping keyword annotated ranges along the sequence. Note that indices
|
285 |
+
will index into the *sequence*, not the function token array which has a
|
286 |
+
<pad> prefix.
|
287 |
+
"""
|
288 |
+
assert keyword_preds.ndim == 2
|
289 |
+
assert keyword_preds.shape[1] == self.config.keyword_vocabulary_size
|
290 |
+
|
291 |
+
keyword_positions: dict[str, list[range]] = defaultdict(list)
|
292 |
+
for position, keyword_id in zip(*np.nonzero(keyword_preds)):
|
293 |
+
keyword = self.keywords_vocabulary[keyword_id]
|
294 |
+
keyword_positions[keyword].append(range(position, position + 1))
|
295 |
+
|
296 |
+
annotations: list[FunctionAnnotation] = []
|
297 |
+
for keyword, ranges in keyword_positions.items():
|
298 |
+
for range_ in merge_ranges(ranges):
|
299 |
+
annotation = FunctionAnnotation(
|
300 |
+
label=keyword,
|
301 |
+
start=range_.start, # one-index inclusive (BOS shifts indexes +1)
|
302 |
+
end=range_.stop - 1, # one-index exclusive -> one-index inclusive
|
303 |
+
)
|
304 |
+
annotations.append(annotation)
|
305 |
+
|
306 |
+
return annotations
|
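A sketch of driving the decoder above end to end, assuming function tokens for a sequence are already available (normally produced by ESM3's function track) and that the ESM3 open weights and tokenizer assets are present locally. The loaders come from esm/pretrained.py and esm/tokenization elsewhere in this commit; the all-"<none>" token tensor is only a stand-in.

import torch

from esm.pretrained import ESM3_function_decoder_v0
from esm.tokenization import get_esm3_model_tokenizers
from esm.utils.constants.models import ESM3_OPEN_SMALL

decoder = ESM3_function_decoder_v0("cpu")
function_tokenizer = get_esm3_model_tokenizers(ESM3_OPEN_SMALL).function

# Stand-in input: [sequence_length, depth] function token ids (all "<none>" here).
function_token_ids = torch.full(
    (128, function_tokenizer.depth),
    fill_value=function_tokenizer.vocab_to_index["<none>"],
    dtype=torch.long,
)

outputs = decoder.decode(
    function_token_ids,
    tokenizer=function_tokenizer,
    annotation_threshold=0.1,
    annotation_min_length=5,
    annotation_gap_merge_max=3,
)
print(outputs["interpro_annotations"], outputs["function_keywords"])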
Dyna-1/esm/models/vqvae.py
ADDED
@@ -0,0 +1,440 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from esm.layers.blocks import UnifiedTransformerBlock
|
5 |
+
from esm.layers.codebook import EMACodebook
|
6 |
+
from esm.layers.structure_proj import Dim6RotStructureHead
|
7 |
+
from esm.layers.transformer_stack import TransformerStack
|
8 |
+
from esm.utils.constants import esm3 as C
|
9 |
+
from esm.utils.misc import knn_graph
|
10 |
+
from esm.utils.structure.affine3d import (
|
11 |
+
Affine3D,
|
12 |
+
build_affine3d_from_coordinates,
|
13 |
+
)
|
14 |
+
from esm.utils.structure.predicted_aligned_error import (
|
15 |
+
compute_predicted_aligned_error,
|
16 |
+
compute_tm,
|
17 |
+
)
|
18 |
+
|
19 |
+
|
20 |
+
class RelativePositionEmbedding(nn.Module):
|
21 |
+
"""
|
22 |
+
Embedding layer for relative position embeddings. `bins` is the number of positions relative
|
23 |
+
to the query position that are considered before clipping. For instance, if `bins=10`, then
|
24 |
+
the relative position embedding will have 21 positions, [-10, 10].
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(self, bins, embedding_dim, init_std=0.02):
|
28 |
+
super().__init__()
|
29 |
+
self.bins = bins
|
30 |
+
|
31 |
+
self.embedding = torch.nn.Embedding(2 * bins + 2, embedding_dim)
|
32 |
+
self.embedding.weight.data.normal_(0, init_std)
|
33 |
+
|
34 |
+
def forward(self, query_residue_index, key_residue_index):
|
35 |
+
"""
|
36 |
+
Input:
|
37 |
+
query_residue_index: (B, ) tensor of source indices (dtype=torch.long)
|
38 |
+
key_residue_index: (B, L) tensor of target indices (dtype=torch.long)
|
39 |
+
Output:
|
40 |
+
embeddings: B x L x embedding_dim tensor of embeddings
|
41 |
+
"""
|
42 |
+
|
43 |
+
assert query_residue_index.dtype == torch.long
|
44 |
+
assert key_residue_index.dtype == torch.long
|
45 |
+
assert query_residue_index.ndim == 1
|
46 |
+
assert key_residue_index.ndim == 2
|
47 |
+
|
48 |
+
diff = key_residue_index - query_residue_index.unsqueeze(1)
|
49 |
+
diff = diff.clamp(-self.bins, self.bins)
|
50 |
+
diff = diff + self.bins + 1 # add 1 to adjust for padding index
|
51 |
+
output = self.embedding(diff)
|
52 |
+
return output
|
53 |
+
|
54 |
+
|
55 |
+
class PairwisePredictionHead(nn.Module):
|
56 |
+
def __init__(
|
57 |
+
self,
|
58 |
+
input_dim: int,
|
59 |
+
downproject_dim: int,
|
60 |
+
hidden_dim: int,
|
61 |
+
n_bins: int,
|
62 |
+
bias: bool = True,
|
63 |
+
pairwise_state_dim: int = 0,
|
64 |
+
):
|
65 |
+
super().__init__()
|
66 |
+
self.downproject = nn.Linear(input_dim, downproject_dim, bias=bias)
|
67 |
+
self.linear1 = nn.Linear(
|
68 |
+
downproject_dim + pairwise_state_dim, hidden_dim, bias=bias
|
69 |
+
)
|
70 |
+
self.activation_fn = nn.GELU()
|
71 |
+
self.norm = nn.LayerNorm(hidden_dim)
|
72 |
+
self.linear2 = nn.Linear(hidden_dim, n_bins, bias=bias)
|
73 |
+
|
74 |
+
def forward(self, x, pairwise: torch.Tensor | None = None):
|
75 |
+
"""
|
76 |
+
Args:
|
77 |
+
x: [B x L x D]
|
78 |
+
|
79 |
+
Output:
|
80 |
+
[B x L x L x K]
|
81 |
+
"""
|
82 |
+
x = self.downproject(x)
|
83 |
+
# Let x_i be a vector of size (B, D).
|
84 |
+
# Input is {x_1, ..., x_L} of size (B, L, D)
|
85 |
+
# Output is 2D where x_ij = cat([x_i * x_j, x_i - x_j])
|
86 |
+
q, k = x.chunk(2, dim=-1)
|
87 |
+
|
88 |
+
prod = q[:, None, :, :] * k[:, :, None, :]
|
89 |
+
diff = q[:, None, :, :] - k[:, :, None, :]
|
90 |
+
x_2d = [prod, diff]
|
91 |
+
if pairwise is not None:
|
92 |
+
x_2d.append(pairwise)
|
93 |
+
x = torch.cat(x_2d, dim=-1)
|
94 |
+
x = self.linear1(x)
|
95 |
+
x = self.activation_fn(x)
|
96 |
+
x = self.norm(x)
|
97 |
+
x = self.linear2(x)
|
98 |
+
return x
|
99 |
+
|
100 |
+
|
101 |
+
class RegressionHead(nn.Module):
|
102 |
+
def __init__(self, embed_dim: int, output_dim: int):
|
103 |
+
super().__init__()
|
104 |
+
self.dense = nn.Linear(embed_dim, embed_dim)
|
105 |
+
self.activation_fn = nn.GELU()
|
106 |
+
self.norm = nn.LayerNorm(embed_dim)
|
107 |
+
self.output = nn.Linear(embed_dim, output_dim)
|
108 |
+
|
109 |
+
def forward(self, features):
|
110 |
+
x = self.dense(features)
|
111 |
+
x = self.activation_fn(x)
|
112 |
+
x = self.norm(x)
|
113 |
+
x = self.output(x)
|
114 |
+
return x
|
115 |
+
|
116 |
+
|
117 |
+
class CategoricalMixture:
|
118 |
+
def __init__(self, param, bins=50, start=0, end=1):
|
119 |
+
# All tensors are of shape ..., bins.
|
120 |
+
self.logits = param
|
121 |
+
bins = torch.linspace(
|
122 |
+
start, end, bins + 1, device=self.logits.device, dtype=torch.float32
|
123 |
+
)
|
124 |
+
self.v_bins = (bins[:-1] + bins[1:]) / 2
|
125 |
+
|
126 |
+
def log_prob(self, true):
|
127 |
+
# Shapes are:
|
128 |
+
# self.probs: ... x bins
|
129 |
+
# true : ... (floating point # for target)
|
130 |
+
true_index = (
|
131 |
+
(true.unsqueeze(-1) - self.v_bins[[None] * true.ndim]).abs().argmin(-1)
|
132 |
+
)
|
133 |
+
nll = self.logits.log_softmax(-1)
|
134 |
+
return torch.take_along_dim(nll, true_index.unsqueeze(-1), dim=-1).squeeze(-1)
|
135 |
+
|
136 |
+
def mean(self):
|
137 |
+
return (
|
138 |
+
self.logits.to(self.v_bins.dtype).softmax(-1) @ self.v_bins.unsqueeze(1)
|
139 |
+
).squeeze(-1)
|
140 |
+
|
141 |
+
def median(self):
|
142 |
+
return self.v_bins[self.logits.max(-1).indices]
|
143 |
+
|
144 |
+
|
145 |
+
class GeometricEncoderStack(TransformerStack):
|
146 |
+
def __init__(self, d_model, n_heads, v_heads, n_layers):
|
147 |
+
super().__init__(d_model, n_heads, v_heads, 0)
|
148 |
+
self.blocks = nn.ModuleList(
|
149 |
+
[
|
150 |
+
UnifiedTransformerBlock(
|
151 |
+
d_model,
|
152 |
+
n_heads,
|
153 |
+
v_heads=v_heads,
|
154 |
+
use_geom_attn=True,
|
155 |
+
use_plain_attn=False,
|
156 |
+
expansion_ratio=4,
|
157 |
+
bias=True,
|
158 |
+
)
|
159 |
+
for i in range(n_layers)
|
160 |
+
]
|
161 |
+
)
|
162 |
+
self.norm = nn.Identity()
|
163 |
+
|
164 |
+
|
165 |
+
def batched_gather(data, inds, dim=0, no_batch_dims=0):
|
166 |
+
ranges = []
|
167 |
+
for i, s in enumerate(data.shape[:no_batch_dims]):
|
168 |
+
r = torch.arange(s)
|
169 |
+
r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
|
170 |
+
ranges.append(r)
|
171 |
+
|
172 |
+
remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)]
|
173 |
+
remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
|
174 |
+
ranges.extend(remaining_dims)
|
175 |
+
return data[ranges]
|
176 |
+
|
177 |
+
|
178 |
+
def node_gather(s: torch.Tensor, edges: torch.Tensor) -> torch.Tensor:
|
179 |
+
return batched_gather(s.unsqueeze(-3), edges, -2, no_batch_dims=len(s.shape) - 1)
|
180 |
+
|
181 |
+
|
182 |
+
class StructureTokenEncoder(nn.Module):
|
183 |
+
def __init__(self, d_model, n_heads, v_heads, n_layers, d_out, n_codes):
|
184 |
+
super().__init__()
|
185 |
+
# We only support fully-geometric structure token encoders for now...
|
186 |
+
# setting n_layers_geom to something that's not n_layers won't work because
|
187 |
+
# sequence ID isn't supported fully in this repo for plain-old transformers
|
188 |
+
self.transformer = GeometricEncoderStack(d_model, n_heads, v_heads, n_layers)
|
189 |
+
self.pre_vq_proj = nn.Linear(d_model, d_out)
|
190 |
+
self.codebook = EMACodebook(n_codes, d_out)
|
191 |
+
self.relative_positional_embedding = RelativePositionEmbedding(
|
192 |
+
32, d_model, init_std=0.02
|
193 |
+
)
|
194 |
+
self.knn = 16
|
195 |
+
|
196 |
+
def encode_local_structure(
|
197 |
+
self,
|
198 |
+
coords: torch.Tensor,
|
199 |
+
affine: Affine3D,
|
200 |
+
attention_mask: torch.Tensor,
|
201 |
+
sequence_id: torch.Tensor | None,
|
202 |
+
affine_mask: torch.Tensor,
|
203 |
+
residue_index: torch.Tensor | None = None,
|
204 |
+
):
|
205 |
+
"""This function allows for a multi-layered encoder to encode tokens with a local receptive fields. The implementation is as follows:
|
206 |
+
|
207 |
+
1. Starting with (B, L) frames, we find the KNN in structure space. This now gives us (B, L, K) where the last dimension is the local
|
208 |
+
neighborhood of all (B, L) residues.
|
209 |
+
2. We reshape these frames to (B*L, K) so now we have a large batch of a bunch of local neighborhoods.
|
210 |
+
3. Pass the (B*L, K) local neighborhoods through a stack of geometric reasoning blocks, effectively getting all-to-all communication between
|
211 |
+
all frames in the local neighborhood.
|
212 |
+
4. This gives (B*L, K, d_model) embeddings, from which we need to get a single embedding per local neighborhood. We do this by simply
|
213 |
+
taking the embedding corresponding to the query node. This gives us (B*L, d_model) embeddings.
|
214 |
+
5. Reshape back to (B, L, d_model) embeddings
|
215 |
+
"""
|
216 |
+
assert coords.size(-1) == 3 and coords.size(-2) == 3, "need N, CA, C"
|
217 |
+
with torch.no_grad():
|
218 |
+
knn_edges, _ = self.find_knn_edges(
|
219 |
+
coords,
|
220 |
+
~attention_mask,
|
221 |
+
coord_mask=affine_mask,
|
222 |
+
sequence_id=sequence_id,
|
223 |
+
knn=self.knn,
|
224 |
+
)
|
225 |
+
B, L, E = knn_edges.shape
|
226 |
+
|
227 |
+
affine_tensor = affine.tensor # for easier manipulation
|
228 |
+
T_D = affine_tensor.size(-1)
|
229 |
+
knn_affine_tensor = node_gather(affine_tensor, knn_edges)
|
230 |
+
knn_affine_tensor = knn_affine_tensor.view(-1, E, T_D).contiguous()
|
231 |
+
affine = Affine3D.from_tensor(knn_affine_tensor)
|
232 |
+
knn_sequence_id = (
|
233 |
+
node_gather(sequence_id.unsqueeze(-1), knn_edges).view(-1, E)
|
234 |
+
if sequence_id is not None
|
235 |
+
else torch.zeros(B * L, E, dtype=torch.int64, device=coords.device)
|
236 |
+
)
|
237 |
+
knn_affine_mask = node_gather(affine_mask.unsqueeze(-1), knn_edges).view(
|
238 |
+
-1, E
|
239 |
+
)
|
240 |
+
knn_chain_id = torch.zeros(
|
241 |
+
B * L, E, dtype=torch.int64, device=coords.device
|
242 |
+
)
|
243 |
+
|
244 |
+
if residue_index is None:
|
245 |
+
res_idxs = knn_edges.view(-1, E)
|
246 |
+
else:
|
247 |
+
res_idxs = node_gather(residue_index.unsqueeze(-1), knn_edges).view(
|
248 |
+
-1, E
|
249 |
+
)
|
250 |
+
|
251 |
+
z = self.relative_positional_embedding(res_idxs[:, 0], res_idxs)
|
252 |
+
|
253 |
+
z, _, _ = self.transformer.forward(
|
254 |
+
x=z,
|
255 |
+
sequence_id=knn_sequence_id,
|
256 |
+
affine=affine,
|
257 |
+
affine_mask=knn_affine_mask,
|
258 |
+
chain_id=knn_chain_id,
|
259 |
+
)
|
260 |
+
|
261 |
+
# Unflatten the output and take the query node embedding, which will always be the first one because
|
262 |
+
# a node has distance 0 with itself and the KNN are sorted.
|
263 |
+
z = z.view(B, L, E, -1)
|
264 |
+
z = z[:, :, 0, :]
|
265 |
+
|
266 |
+
return z
|
267 |
+
|
268 |
+
@staticmethod
|
269 |
+
def find_knn_edges(
|
270 |
+
coords,
|
271 |
+
padding_mask,
|
272 |
+
coord_mask,
|
273 |
+
sequence_id: torch.Tensor | None = None,
|
274 |
+
knn: int | None = None,
|
275 |
+
) -> tuple:
|
276 |
+
assert knn is not None, "Must specify a non-null knn to find_knn_edges"
|
277 |
+
# Coords are N, CA, C
|
278 |
+
coords = coords.clone()
|
279 |
+
coords[~coord_mask] = 0
|
280 |
+
|
281 |
+
if sequence_id is None:
|
282 |
+
sequence_id = torch.zeros(
|
283 |
+
(coords.shape[0], coords.shape[1]), device=coords.device
|
284 |
+
).long()
|
285 |
+
|
286 |
+
with torch.no_grad(), torch.cuda.amp.autocast(enabled=False): # type: ignore
|
287 |
+
ca = coords[..., 1, :]
|
288 |
+
edges, edge_mask = knn_graph(
|
289 |
+
ca, coord_mask, padding_mask, sequence_id, no_knn=knn
|
290 |
+
)
|
291 |
+
|
292 |
+
return edges, edge_mask
|
293 |
+
|
294 |
+
def encode(
|
295 |
+
self,
|
296 |
+
coords: torch.Tensor,
|
297 |
+
attention_mask: torch.Tensor | None = None,
|
298 |
+
sequence_id: torch.Tensor | None = None,
|
299 |
+
residue_index: torch.Tensor | None = None,
|
300 |
+
):
|
301 |
+
coords = coords[..., :3, :]
|
302 |
+
affine, affine_mask = build_affine3d_from_coordinates(coords=coords)
|
303 |
+
|
304 |
+
if attention_mask is None:
|
305 |
+
attention_mask = torch.ones_like(affine_mask, dtype=torch.bool)
|
306 |
+
attention_mask = attention_mask.bool()
|
307 |
+
|
308 |
+
if sequence_id is None:
|
309 |
+
sequence_id = torch.zeros_like(affine_mask, dtype=torch.int64)
|
310 |
+
|
311 |
+
z = self.encode_local_structure(
|
312 |
+
coords=coords,
|
313 |
+
affine=affine,
|
314 |
+
attention_mask=attention_mask,
|
315 |
+
sequence_id=sequence_id,
|
316 |
+
affine_mask=affine_mask,
|
317 |
+
residue_index=residue_index,
|
318 |
+
)
|
319 |
+
|
320 |
+
z = z.masked_fill(~affine_mask.unsqueeze(2), 0)
|
321 |
+
z = self.pre_vq_proj(z)
|
322 |
+
|
323 |
+
z_q, min_encoding_indices, _ = self.codebook(z)
|
324 |
+
|
325 |
+
return z_q, min_encoding_indices
|
326 |
+
|
327 |
+
|
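# Shape walk-through of the five steps described in encode_local_structure above,
# using plain tensors in place of the geometric transformer; the numbers are
# arbitrary and the snippet is purely illustrative.
import torch

B, L, K, d_model = 2, 50, 16, 8            # K matches self.knn above

# Steps 1-2: per-residue KNN neighbourhoods, flattened into one large batch.
neighbourhoods = torch.randn(B, L, K, d_model)
flat = neighbourhoods.view(B * L, K, d_model)

# Step 3 would run the geometric blocks over `flat` (all-to-all within each
# K-sized neighbourhood); identity stands in for it here.
encoded = flat

# Step 4: keep only the query residue. It is always index 0 because each residue
# has distance zero to itself and the KNN list is sorted by distance.
query_only = encoded[:, 0, :]              # (B*L, d_model)

# Step 5: reshape back to one embedding per residue.
per_residue = query_only.view(B, L, d_model)   # (B, L, d_model)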
328 |
+
class StructureTokenDecoder(nn.Module):
|
329 |
+
def __init__(self, d_model, n_heads, n_layers):
|
330 |
+
super().__init__()
|
331 |
+
self.decoder_channels = d_model
|
332 |
+
|
333 |
+
self.vqvae_codebook_size = C.VQVAE_CODEBOOK_SIZE
|
334 |
+
self.special_tokens = C.VQVAE_SPECIAL_TOKENS
|
335 |
+
self.max_pae_bin = C.VQVAE_MAX_PAE_BIN
|
336 |
+
|
337 |
+
self.embed = nn.Embedding(
|
338 |
+
self.vqvae_codebook_size + len(self.special_tokens), d_model
|
339 |
+
)
|
340 |
+
self.decoder_stack = TransformerStack(
|
341 |
+
d_model, n_heads, 1, n_layers, scale_residue=False, n_layers_geom=0
|
342 |
+
)
|
343 |
+
|
344 |
+
self.affine_output_projection = Dim6RotStructureHead(
|
345 |
+
self.decoder_channels, 10, predict_torsion_angles=False
|
346 |
+
)
|
347 |
+
|
348 |
+
direction_loss_bins = C.VQVAE_DIRECTION_LOSS_BINS
|
349 |
+
pae_bins = C.VQVAE_PAE_BINS
|
350 |
+
self.pairwise_bins = [
|
351 |
+
64, # distogram
|
352 |
+
direction_loss_bins * 6, # direction bins
|
353 |
+
pae_bins, # predicted aligned error
|
354 |
+
]
|
355 |
+
self.pairwise_classification_head = PairwisePredictionHead(
|
356 |
+
self.decoder_channels,
|
357 |
+
downproject_dim=128,
|
358 |
+
hidden_dim=128,
|
359 |
+
n_bins=sum(self.pairwise_bins),
|
360 |
+
bias=False,
|
361 |
+
)
|
362 |
+
|
363 |
+
plddt_bins = C.VQVAE_PLDDT_BINS
|
364 |
+
self.plddt_head = RegressionHead(
|
365 |
+
embed_dim=self.decoder_channels, output_dim=plddt_bins
|
366 |
+
)
|
367 |
+
|
368 |
+
def decode(
|
369 |
+
self,
|
370 |
+
structure_tokens: torch.Tensor,
|
371 |
+
attention_mask: torch.Tensor | None = None,
|
372 |
+
sequence_id: torch.Tensor | None = None,
|
373 |
+
):
|
374 |
+
if attention_mask is None:
|
375 |
+
attention_mask = torch.ones_like(structure_tokens, dtype=torch.bool)
|
376 |
+
|
377 |
+
attention_mask = attention_mask.bool()
|
378 |
+
if sequence_id is None:
|
379 |
+
sequence_id = torch.zeros_like(structure_tokens, dtype=torch.int64)
|
380 |
+
# not supported for now
|
381 |
+
chain_id = torch.zeros_like(structure_tokens, dtype=torch.int64)
|
382 |
+
|
383 |
+
# check that BOS and EOS are set correctly
|
384 |
+
assert (
|
385 |
+
structure_tokens[:, 0].eq(self.special_tokens["BOS"]).all()
|
386 |
+
), "First token in structure_tokens must be BOS token"
|
387 |
+
assert (
|
388 |
+
structure_tokens[
|
389 |
+
torch.arange(structure_tokens.shape[0]), attention_mask.sum(1) - 1
|
390 |
+
]
|
391 |
+
.eq(self.special_tokens["EOS"])
|
392 |
+
.all()
|
393 |
+
), "Last token in structure_tokens must be EOS token"
|
394 |
+
assert (
|
395 |
+
(structure_tokens < 0).sum() == 0
|
396 |
+
), "All structure tokens set to -1 should be replaced with BOS, EOS, PAD, or MASK tokens by now, but that isn't the case!"
|
397 |
+
|
398 |
+
x = self.embed(structure_tokens)
|
399 |
+
# !!! NOTE: Attention mask is actually unused here so watch out
|
400 |
+
x, _, _ = self.decoder_stack.forward(
|
401 |
+
x, affine=None, affine_mask=None, sequence_id=sequence_id, chain_id=chain_id
|
402 |
+
)
|
403 |
+
|
404 |
+
tensor7_affine, bb_pred = self.affine_output_projection(
|
405 |
+
x, affine=None, affine_mask=torch.zeros_like(attention_mask)
|
406 |
+
)
|
407 |
+
|
408 |
+
pae, ptm = None, None
|
409 |
+
pairwise_logits = self.pairwise_classification_head(x)
|
410 |
+
_, _, pae_logits = [
|
411 |
+
(o if o.numel() > 0 else None)
|
412 |
+
for o in pairwise_logits.split(self.pairwise_bins, dim=-1)
|
413 |
+
]
|
414 |
+
|
415 |
+
special_tokens_mask = structure_tokens >= min(self.special_tokens.values())
|
416 |
+
pae = compute_predicted_aligned_error(
|
417 |
+
pae_logits, # type: ignore
|
418 |
+
aa_mask=~special_tokens_mask,
|
419 |
+
sequence_id=sequence_id,
|
420 |
+
max_bin=self.max_pae_bin,
|
421 |
+
)
|
422 |
+
# This might be broken for chainbreak tokens? We might align to the chainbreak
|
423 |
+
ptm = compute_tm(
|
424 |
+
pae_logits, # type: ignore
|
425 |
+
aa_mask=~special_tokens_mask,
|
426 |
+
max_bin=self.max_pae_bin,
|
427 |
+
)
|
428 |
+
|
429 |
+
plddt_logits = self.plddt_head(x)
|
430 |
+
plddt_value = CategoricalMixture(
|
431 |
+
plddt_logits, bins=plddt_logits.shape[-1]
|
432 |
+
).mean()
|
433 |
+
|
434 |
+
return dict(
|
435 |
+
tensor7_affine=tensor7_affine,
|
436 |
+
bb_pred=bb_pred,
|
437 |
+
plddt=plddt_value,
|
438 |
+
ptm=ptm,
|
439 |
+
predicted_aligned_error=pae,
|
440 |
+
)
|
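A sketch of tokenizing a structure with the encoder above, assuming the ESM3 open weights are available locally. ESM3_structure_encoder_v0 is defined in esm/pretrained.py below and ESMProtein in esm/sdk/api.py later in this commit; the PDB path is a placeholder.

import torch

from esm.pretrained import ESM3_structure_encoder_v0
from esm.sdk.api import ESMProtein

encoder = ESM3_structure_encoder_v0("cpu")

# atom37 coordinates for a single chain; "example.pdb" is a placeholder path.
protein = ESMProtein.from_pdb("example.pdb")
coords = protein.coordinates.to(torch.float32).unsqueeze(0)  # (1, L, 37, 3)

with torch.no_grad():
    z_q, structure_token_ids = encoder.encode(coords)   # encode() keeps N, CA, C

print(structure_token_ids.shape)  # (1, L) discrete codes from the EMA codebook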
Dyna-1/esm/pretrained.py
ADDED
@@ -0,0 +1,132 @@
1 |
+
from typing import Callable
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from esm.models.esm3 import ESM3
|
7 |
+
from esm.models.esmc import ESMC
|
8 |
+
from esm.models.function_decoder import FunctionTokenDecoder
|
9 |
+
from esm.models.vqvae import (
|
10 |
+
StructureTokenDecoder,
|
11 |
+
StructureTokenEncoder,
|
12 |
+
)
|
13 |
+
from esm.tokenization import (
|
14 |
+
get_esm3_model_tokenizers,
|
15 |
+
get_esmc_model_tokenizers,
|
16 |
+
)
|
17 |
+
from esm.utils.constants.esm3 import data_root
|
18 |
+
from esm.utils.constants.models import (
|
19 |
+
ESM3_FUNCTION_DECODER_V0,
|
20 |
+
ESM3_OPEN_SMALL,
|
21 |
+
ESM3_STRUCTURE_DECODER_V0,
|
22 |
+
ESM3_STRUCTURE_ENCODER_V0,
|
23 |
+
ESMC_300M,
|
24 |
+
ESMC_600M,
|
25 |
+
)
|
26 |
+
|
27 |
+
ModelBuilder = Callable[[torch.device | str], nn.Module]
|
28 |
+
|
29 |
+
|
30 |
+
def ESM3_structure_encoder_v0(device: torch.device | str = "cpu"):
|
31 |
+
with torch.device(device):
|
32 |
+
model = StructureTokenEncoder(
|
33 |
+
d_model=1024, n_heads=1, v_heads=128, n_layers=2, d_out=128, n_codes=4096
|
34 |
+
).eval()
|
35 |
+
state_dict = torch.load(
|
36 |
+
data_root("esm3") / "data/weights/esm3_structure_encoder_v0.pth",
|
37 |
+
map_location=device,
|
38 |
+
)
|
39 |
+
model.load_state_dict(state_dict)
|
40 |
+
return model
|
41 |
+
|
42 |
+
|
43 |
+
def ESM3_structure_decoder_v0(device: torch.device | str = "cpu"):
|
44 |
+
with torch.device(device):
|
45 |
+
model = StructureTokenDecoder(d_model=1280, n_heads=20, n_layers=30).eval()
|
46 |
+
state_dict = torch.load(
|
47 |
+
data_root("esm3") / "data/weights/esm3_structure_decoder_v0.pth",
|
48 |
+
map_location=device,
|
49 |
+
)
|
50 |
+
model.load_state_dict(state_dict)
|
51 |
+
return model
|
52 |
+
|
53 |
+
|
54 |
+
def ESM3_function_decoder_v0(device: torch.device | str = "cpu"):
|
55 |
+
with torch.device(device):
|
56 |
+
model = FunctionTokenDecoder().eval()
|
57 |
+
state_dict = torch.load(
|
58 |
+
data_root("esm3") / "data/weights/esm3_function_decoder_v0.pth",
|
59 |
+
map_location=device,
|
60 |
+
)
|
61 |
+
model.load_state_dict(state_dict)
|
62 |
+
return model
|
63 |
+
|
64 |
+
|
65 |
+
def ESMC_300M_202412(device: torch.device | str = "cpu"):
|
66 |
+
with torch.device(device):
|
67 |
+
model = ESMC(
|
68 |
+
d_model=960, n_heads=15, n_layers=30, tokenizer=get_esmc_model_tokenizers()
|
69 |
+
).eval()
|
70 |
+
state_dict = torch.load(
|
71 |
+
data_root("esmc-300") / "data/weights/esmc_300m_2024_12_v0.pth",
|
72 |
+
map_location=device,
|
73 |
+
)
|
74 |
+
model.load_state_dict(state_dict)
|
75 |
+
|
76 |
+
return model
|
77 |
+
|
78 |
+
|
79 |
+
def ESMC_600M_202412(device: torch.device | str = "cpu"):
|
80 |
+
with torch.device(device):
|
81 |
+
model = ESMC(
|
82 |
+
d_model=1152, n_heads=18, n_layers=36, tokenizer=get_esmc_model_tokenizers()
|
83 |
+
).eval()
|
84 |
+
state_dict = torch.load(
|
85 |
+
data_root("esmc-600") / "data/weights/esmc_600m_2024_12_v0.pth",
|
86 |
+
map_location=device,
|
87 |
+
)
|
88 |
+
model.load_state_dict(state_dict)
|
89 |
+
|
90 |
+
return model
|
91 |
+
|
92 |
+
|
93 |
+
def ESM3_sm_open_v0(device: torch.device | str = "cpu"):
|
94 |
+
with torch.device(device):
|
95 |
+
model = ESM3(
|
96 |
+
d_model=1536,
|
97 |
+
n_heads=24,
|
98 |
+
v_heads=256,
|
99 |
+
n_layers=48,
|
100 |
+
structure_encoder_fn=ESM3_structure_encoder_v0,
|
101 |
+
structure_decoder_fn=ESM3_structure_decoder_v0,
|
102 |
+
function_decoder_fn=ESM3_function_decoder_v0,
|
103 |
+
tokenizers=get_esm3_model_tokenizers(ESM3_OPEN_SMALL),
|
104 |
+
).eval()
|
105 |
+
state_dict = torch.load(
|
106 |
+
data_root("esm3") / "data/weights/esm3_sm_open_v1.pth", map_location=device
|
107 |
+
)
|
108 |
+
model.load_state_dict(state_dict)
|
109 |
+
return model
|
110 |
+
|
111 |
+
|
112 |
+
LOCAL_MODEL_REGISTRY: dict[str, ModelBuilder] = {
|
113 |
+
ESM3_OPEN_SMALL: ESM3_sm_open_v0,
|
114 |
+
ESM3_STRUCTURE_ENCODER_V0: ESM3_structure_encoder_v0,
|
115 |
+
ESM3_STRUCTURE_DECODER_V0: ESM3_structure_decoder_v0,
|
116 |
+
ESM3_FUNCTION_DECODER_V0: ESM3_function_decoder_v0,
|
117 |
+
ESMC_600M: ESMC_600M_202412,
|
118 |
+
ESMC_300M: ESMC_300M_202412,
|
119 |
+
}
|
120 |
+
|
121 |
+
|
122 |
+
def load_local_model(
|
123 |
+
model_name: str, device: torch.device = torch.device("cpu")
|
124 |
+
) -> nn.Module:
|
125 |
+
if model_name not in LOCAL_MODEL_REGISTRY:
|
126 |
+
raise ValueError(f"Model {model_name} not found in local model registry.")
|
127 |
+
return LOCAL_MODEL_REGISTRY[model_name](device)
|
128 |
+
|
129 |
+
|
130 |
+
# Register custom versions of ESM3 for use with the local inference API
|
131 |
+
def register_local_model(model_name: str, model_builder: ModelBuilder) -> None:
|
132 |
+
LOCAL_MODEL_REGISTRY[model_name] = model_builder
|
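A sketch of the registry hook at the bottom of this file: registering a custom builder so it can be loaded by name through load_local_model. The builder name, checkpoint path, and model dimensions below are hypothetical; only register_local_model, load_local_model, and the ESMC constructor are taken from this commit.

import torch

from esm.models.esmc import ESMC
from esm.pretrained import load_local_model, register_local_model
from esm.tokenization import get_esmc_model_tokenizers


def my_finetuned_esmc(device: torch.device | str = "cpu"):
    # Hypothetical builder mirroring ESMC_600M_202412 above, pointing at a local checkpoint.
    with torch.device(device):
        model = ESMC(
            d_model=1152, n_heads=18, n_layers=36, tokenizer=get_esmc_model_tokenizers()
        ).eval()
    state_dict = torch.load("my_finetuned_esmc.pth", map_location=device)  # placeholder path
    model.load_state_dict(state_dict)
    return model


register_local_model("my-finetuned-esmc", my_finetuned_esmc)
model = load_local_model("my-finetuned-esmc", device=torch.device("cpu"))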
Dyna-1/esm/sdk/__init__.py
ADDED
@@ -0,0 +1,22 @@
1 |
+
import os
|
2 |
+
|
3 |
+
from esm.sdk.forge import ESM3ForgeInferenceClient
|
4 |
+
|
5 |
+
# Note: please do not import ESM3SageMakerClient here since that requires AWS SDK.
|
6 |
+
|
7 |
+
|
8 |
+
def client(
|
9 |
+
model="esm3-sm-open-v1",
|
10 |
+
url="https://forge.evolutionaryscale.ai",
|
11 |
+
token=os.environ.get("ESM_API_KEY", ""),
|
12 |
+
request_timeout=None,
|
13 |
+
):
|
14 |
+
"""
|
15 |
+
Args:
|
16 |
+
model: Name of the model to use.
|
17 |
+
url: URL of a forge server.
|
18 |
+
token: User's API token.
|
19 |
+
request_timeout: Amount of time to wait for a request to finish.
|
20 |
+
Default is wait indefinitely.
|
21 |
+
"""
|
22 |
+
return ESM3ForgeInferenceClient(model, url, token, request_timeout)
|
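A sketch of calling the Forge helper above, assuming an ESM_API_KEY environment variable is set and that the returned ESM3ForgeInferenceClient implements the generate() endpoint of the ESM3InferenceClient interface defined in esm/sdk/api.py below. The masked sequence and generation settings are illustrative only.

from esm.sdk import client
from esm.sdk.api import ESMProtein, GenerationConfig

# client() reads the API token from ESM_API_KEY by default.
forge_client = client(model="esm3-sm-open-v1")

# "_" marks masked positions to be filled in; the sequence below is illustrative.
protein = ESMProtein(sequence="MV________________KAAW")
generated = forge_client.generate(
    protein, GenerationConfig(track="sequence", num_steps=8)
)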
Dyna-1/esm/sdk/api.py
ADDED
@@ -0,0 +1,445 @@
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from abc import ABC
|
4 |
+
from typing import List, Sequence
|
5 |
+
|
6 |
+
import attr
|
7 |
+
import torch
|
8 |
+
from attr import asdict, define
|
9 |
+
|
10 |
+
import esm.utils.constants.api as C
|
11 |
+
from esm.tokenization import (
|
12 |
+
TokenizerCollectionProtocol,
|
13 |
+
get_esm3_model_tokenizers,
|
14 |
+
)
|
15 |
+
from esm.utils import encoding
|
16 |
+
from esm.utils.constants.models import ESM3_OPEN_SMALL
|
17 |
+
from esm.utils.misc import (
|
18 |
+
get_chainbreak_boundaries_from_sequence,
|
19 |
+
)
|
20 |
+
from esm.utils.structure.protein_chain import ProteinChain
|
21 |
+
from esm.utils.structure.protein_complex import ProteinComplex
|
22 |
+
from esm.utils.types import FunctionAnnotation, PathOrBuffer
|
23 |
+
|
24 |
+
|
25 |
+
class ProteinType(ABC): ...
|
26 |
+
|
27 |
+
|
28 |
+
## Basic Types
|
29 |
+
@define
|
30 |
+
class ESMProtein(ProteinType):
|
31 |
+
# Tracks
|
32 |
+
sequence: str | None = None
|
33 |
+
secondary_structure: str | None = None
|
34 |
+
sasa: list[float | None] | None = None
|
35 |
+
function_annotations: list[FunctionAnnotation] | None = None
|
36 |
+
coordinates: torch.Tensor | None = None
|
37 |
+
|
38 |
+
# Metrics
|
39 |
+
plddt: torch.Tensor | None = None
|
40 |
+
ptm: torch.Tensor | None = None
|
41 |
+
|
42 |
+
|
43 |
+
# When calling EvolutionaryScale API, use this flag to disclose any
|
44 |
+
# sequences that may potentially have concerns.
|
45 |
+
# Such sequences may not go through standard safety filter for approved users.
|
46 |
+
# Reach out if interested in using this.
|
47 |
+
potential_sequence_of_concern: bool = False
|
48 |
+
|
49 |
+
def __len__(self):
|
50 |
+
if self.sequence is not None:
|
51 |
+
return len(self.sequence)
|
52 |
+
elif self.secondary_structure is not None:
|
53 |
+
return len(self.secondary_structure)
|
54 |
+
elif self.sasa is not None:
|
55 |
+
return len(self.sasa)
|
56 |
+
elif self.coordinates is not None:
|
57 |
+
return self.coordinates.size(0)
|
58 |
+
else:
|
59 |
+
raise ValueError("No track to determine length from.")
|
60 |
+
|
61 |
+
@classmethod
|
62 |
+
def from_pdb(
|
63 |
+
cls,
|
64 |
+
path: PathOrBuffer,
|
65 |
+
chain_id: str = "detect",
|
66 |
+
id: str | None = None,
|
67 |
+
is_predicted: bool = False,
|
68 |
+
) -> ESMProtein:
|
69 |
+
protein_chain = ProteinChain.from_pdb(
|
70 |
+
path=path, chain_id=chain_id, id=id, is_predicted=is_predicted
|
71 |
+
)
|
72 |
+
return cls.from_protein_chain(protein_chain)
|
73 |
+
|
74 |
+
@classmethod
|
75 |
+
def from_protein_chain(
|
76 |
+
cls, protein_chain: ProteinChain, with_annotations: bool = False
|
77 |
+
) -> ESMProtein:
|
78 |
+
# By default, we don't annotate with DSSP / SASA, which are expensive.
|
79 |
+
# If mkdssp is installed, we can annotate with a flag.
|
80 |
+
if with_annotations:
|
81 |
+
return ESMProtein(
|
82 |
+
sequence=protein_chain.sequence,
|
83 |
+
secondary_structure=protein_chain.dssp().tolist(),
|
84 |
+
sasa=protein_chain.sasa().tolist(),
|
85 |
+
function_annotations=None,
|
86 |
+
coordinates=torch.tensor(protein_chain.atom37_positions),
|
87 |
+
)
|
88 |
+
else:
|
89 |
+
return ESMProtein(
|
90 |
+
sequence=protein_chain.sequence,
|
91 |
+
secondary_structure=None,
|
92 |
+
sasa=None,
|
93 |
+
function_annotations=None,
|
94 |
+
coordinates=torch.tensor(protein_chain.atom37_positions),
|
95 |
+
)
|
96 |
+
|
97 |
+
@classmethod
|
98 |
+
def from_protein_complex(
|
99 |
+
cls, protein_complex: ProteinComplex, with_annotations: bool = False
|
100 |
+
) -> ESMProtein:
|
101 |
+
if with_annotations:
|
102 |
+
raise NotImplementedError(
|
103 |
+
"Annotations are not supported for ProteinComplex yet."
|
104 |
+
)
|
105 |
+
|
106 |
+
return ESMProtein(
|
107 |
+
sequence=protein_complex.sequence,
|
108 |
+
secondary_structure=None,
|
109 |
+
sasa=None,
|
110 |
+
function_annotations=None,
|
111 |
+
coordinates=torch.tensor(protein_complex.atom37_positions),
|
112 |
+
)
|
113 |
+
|
114 |
+
def to_pdb(self, pdb_path: PathOrBuffer) -> None:
|
115 |
+
# Note: Will work for single chains as well and produce same pdb file
|
116 |
+
protein_complex = self.to_protein_complex().infer_oxygen()
|
117 |
+
protein_complex.to_pdb(pdb_path)
|
118 |
+
|
119 |
+
def to_pdb_string(self) -> str:
|
120 |
+
protein_chain = self.to_protein_chain()
|
121 |
+
return protein_chain.to_pdb_string()
|
122 |
+
|
123 |
+
def to_protein_chain(self) -> ProteinChain:
|
124 |
+
if self.coordinates is None:
|
125 |
+
raise ValueError("Coordinates are required to convert to a ProteinChain.")
|
126 |
+
protein_chain = ProteinChain.from_atom37(
|
127 |
+
atom37_positions=self.coordinates.to("cpu").numpy(),
|
128 |
+
id=None,
|
129 |
+
sequence=None if self.sequence is None else self.sequence.replace("_", "X"),
|
130 |
+
chain_id=None,
|
131 |
+
entity_id=None,
|
132 |
+
residue_index=None,
|
133 |
+
insertion_code=None,
|
134 |
+
confidence=None
|
135 |
+
if self.plddt is None
|
136 |
+
else self.plddt.detach().cpu().numpy(),
|
137 |
+
)
|
138 |
+
return protein_chain
|
139 |
+
|
140 |
+
def to_protein_complex(
|
141 |
+
self, copy_annotations_from_ground_truth: ProteinComplex | None = None
|
142 |
+
) -> ProteinComplex:
|
143 |
+
assert (
|
144 |
+
self.sequence is not None
|
145 |
+
), "ESMProtein must have a sequence to convert to ProteinComplex"
|
146 |
+
assert (
|
147 |
+
self.coordinates is not None
|
148 |
+
), "ESMProtein must have coordinates to convert to ProteinComplex"
|
149 |
+
coords = self.coordinates.to("cpu").numpy()
|
150 |
+
|
151 |
+
chain_boundaries = get_chainbreak_boundaries_from_sequence(self.sequence)
|
152 |
+
if copy_annotations_from_ground_truth is not None:
|
153 |
+
gt_chains = list(copy_annotations_from_ground_truth.chain_iter())
|
154 |
+
else:
|
155 |
+
gt_chains = None
|
156 |
+
pred_chains = []
|
157 |
+
for i, (start, end) in enumerate(chain_boundaries):
|
158 |
+
pred_chain = ProteinChain.from_atom37(
|
159 |
+
atom37_positions=coords[start:end],
|
160 |
+
sequence=self.sequence[start:end],
|
161 |
+
chain_id=gt_chains[i].chain_id if gt_chains is not None else None,
|
162 |
+
entity_id=gt_chains[i].entity_id if gt_chains is not None else None,
|
163 |
+
)
|
164 |
+
pred_chains.append(pred_chain)
|
165 |
+
return ProteinComplex.from_chains(pred_chains)
|
166 |
+
|
167 |
+
|
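# Sketch: round-tripping the ESMProtein container defined above through PDB and
# ProteinChain; the file paths are placeholders and nothing beyond the methods
# shown above is assumed.
from esm.sdk.api import ESMProtein

protein = ESMProtein.from_pdb("example.pdb", chain_id="detect")
print(len(protein), protein.sequence[:10])

chain = protein.to_protein_chain()   # ProteinChain built from the atom37 coordinates
protein.to_pdb("roundtrip.pdb")      # writes via to_protein_complex().infer_oxygen()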
168 |
+
@define
|
169 |
+
class ESMProteinTensor(ProteinType):
|
170 |
+
sequence: torch.Tensor | None = None
|
171 |
+
structure: torch.Tensor | None = None
|
172 |
+
secondary_structure: torch.Tensor | None = None
|
173 |
+
sasa: torch.Tensor | None = None
|
174 |
+
function: torch.Tensor | None = None
|
175 |
+
residue_annotations: torch.Tensor | None = None
|
176 |
+
coordinates: torch.Tensor | None = None
|
177 |
+
|
178 |
+
# When calling EvolutionaryScale API, use this flag to disclose any
|
179 |
+
# sequences that may potentially have concerns.
|
180 |
+
# Such sequences may not go through standard safety filter for approved users.
|
181 |
+
# Reach out if interested in using this.
|
182 |
+
potential_sequence_of_concern: bool = False
|
183 |
+
# Control vectors are vectors added to each layer of the model to nudge hidden states to the desired direction.
|
184 |
+
# len(control_vectors) == number of blocks in the model. Each vector in the list has the shape (batch size, sequence length, hidden dim)
|
185 |
+
# so it can be added to the corresponding layer in the model
|
186 |
+
|
187 |
+
def _detect_attribute(self, func, msg):
|
188 |
+
mapped = {
|
189 |
+
k: func(k, v)
|
190 |
+
for k, v in asdict(self).items()
|
191 |
+
if isinstance(v, torch.Tensor)
|
192 |
+
}
|
193 |
+
s = set(mapped.values())
|
194 |
+
if len(s) <= 0:
|
195 |
+
return None
|
196 |
+
if len(s) != 1:
|
197 |
+
raise ValueError(f"Either no tracks or inconsistent {msg}: {mapped}")
|
198 |
+
return next(iter(s))
|
199 |
+
|
200 |
+
def __len__(self) -> int:
|
201 |
+
l = self._detect_attribute(lambda _, x: x.size(0), "length")
|
202 |
+
return l if l is not None else 0
|
203 |
+
|
204 |
+
@property
|
205 |
+
def device(self) -> str | torch.device:
|
206 |
+
d = self._detect_attribute(lambda _, x: x.device, "device")
|
207 |
+
assert d is not None
|
208 |
+
return d
|
209 |
+
|
210 |
+
def to(self, device_or_dtype: str | torch.device | torch.dtype) -> ESMProteinTensor:
|
211 |
+
def _to(name):
|
212 |
+
v = getattr(self, name)
|
213 |
+
if v is not None and isinstance(v, torch.Tensor):
|
214 |
+
setattr(self, name, v.to(device_or_dtype))
|
215 |
+
|
216 |
+
for n in attr.fields(ESMProteinTensor):
|
217 |
+
_to(n.name)
|
218 |
+
|
219 |
+
return self
|
220 |
+
|
221 |
+
@classmethod
|
222 |
+
def empty(
|
223 |
+
cls,
|
224 |
+
length: int,
|
225 |
+
tokenizers: TokenizerCollectionProtocol | None = None,
|
226 |
+
device: torch.device | str = "cpu",
|
227 |
+
) -> ESMProteinTensor:
|
228 |
+
if tokenizers is None:
|
229 |
+
tokenizers = get_esm3_model_tokenizers(ESM3_OPEN_SMALL)
|
230 |
+
|
231 |
+
return ESMProteinTensor(
|
232 |
+
sequence=encoding.get_default_sequence_tokens(
|
233 |
+
length, tokenizers.sequence
|
234 |
+
).to(device),
|
235 |
+
structure=encoding.get_default_structure_tokens(
|
236 |
+
length, tokenizers.structure
|
237 |
+
).to(device),
|
238 |
+
secondary_structure=encoding.get_default_secondary_structure_tokens(
|
239 |
+
length, tokenizers.secondary_structure
|
240 |
+
).to(device),
|
241 |
+
sasa=encoding.get_default_sasa_tokens(length, tokenizers.sasa).to(device),
|
242 |
+
function=encoding.get_default_function_tokens(
|
243 |
+
length, tokenizers.function
|
244 |
+
).to(device),
|
245 |
+
residue_annotations=encoding.get_default_residue_annotation_tokens(
|
246 |
+
length, tokenizers.residue_annotations
|
247 |
+
).to(device),
|
248 |
+
)
|
249 |
+
|
250 |
+
|
251 |
+
@define
|
252 |
+
class ESMProteinError(Exception, ProteinType):
|
253 |
+
error_code: int # Error code follows HTTP convention, i.e., 404 NotFoundError, 500 InternalError.
|
254 |
+
error_msg: str
|
255 |
+
|
256 |
+
|
257 |
+
## High Level Endpoint Types
|
258 |
+
@define
|
259 |
+
class GenerationConfig:
|
260 |
+
track: str = ""
|
261 |
+
invalid_ids: Sequence[int] = []
|
262 |
+
# Controls the number of tokens to unmask during each round of iterative generation.
|
263 |
+
schedule: str = attr.field(
|
264 |
+
validator=attr.validators.in_(["cosine", "linear"]), default="cosine"
|
265 |
+
)
|
266 |
+
# Controls which tokens to unmask during each round of iterative generation.
|
267 |
+
# "random" will unmask a correct number of tokens randomly.
|
268 |
+
# "entropy" will unmask the tokens with the lowest logit entropy first.
|
269 |
+
strategy: str = attr.field(
|
270 |
+
validator=attr.validators.in_(["random", "entropy"]), default="entropy"
|
271 |
+
)
|
272 |
+
# Set this to a higher value for better generation results.
|
273 |
+
# Note that this needs to be less than or equal to the sequence length.
|
274 |
+
num_steps: int = 1
|
275 |
+
temperature: float = 1.0
|
276 |
+
temperature_annealing: bool = False
|
277 |
+
top_p: float = 1.0
|
278 |
+
condition_on_coordinates_only: bool = True
|
279 |
+
|
280 |
+
def use_entropy_based_unmasking_strategy(self):
|
281 |
+
"""Use entropy based unmasking strategy during generation."""
|
282 |
+
self.schedule = "cosine"
|
283 |
+
self.strategy = "entropy"
|
284 |
+
self.temperature_annealing = False
|
285 |
+
|
286 |
+
def use_generative_unmasking_strategy(self):
|
287 |
+
"""Use an unmasking strategy that produces more variety of generations."""
|
288 |
+
self.schedule = "cosine"
|
289 |
+
self.strategy = "random"
|
290 |
+
self.temperature_annealing = True
|
291 |
+
|
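# Sketch: configuring iterative generation with the dataclass above. The track
# name and numbers are illustrative; num_steps must not exceed the sequence
# length, and the entropy-based schedule shown above is the default.
from esm.sdk.api import GenerationConfig

config = GenerationConfig(track="structure", num_steps=16, temperature=0.7)

# Switch to the higher-diversity preset defined above when variety matters more
# than per-token accuracy.
config.use_generative_unmasking_strategy()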
292 |
+
|
293 |
+
@define
|
294 |
+
class InverseFoldingConfig:
|
295 |
+
invalid_ids: Sequence[int] = []
|
296 |
+
temperature: float = 1.0
|
297 |
+
|
298 |
+
|
299 |
+
## Low Level Endpoint Types
|
300 |
+
@define
|
301 |
+
class SamplingTrackConfig:
|
302 |
+
temperature: float = 1.0
|
303 |
+
top_p: float = 1.0
|
304 |
+
only_sample_masked_tokens: bool = True
|
305 |
+
invalid_ids: Sequence[int] = []
|
306 |
+
topk_logprobs: int = 0
|
307 |
+
|
308 |
+
|
309 |
+
@define
|
310 |
+
class SamplingConfig:
|
311 |
+
sequence: SamplingTrackConfig | None = attr.field(
|
312 |
+
default=None, metadata={"max_topk": C.MAX_TOPK_SEQUENCE}
|
313 |
+
)
|
314 |
+
structure: SamplingTrackConfig | None = attr.field(
|
315 |
+
default=None, metadata={"max_topk": C.MAX_TOPK_STRUCTURE}
|
316 |
+
)
|
317 |
+
secondary_structure: SamplingTrackConfig | None = attr.field(
|
318 |
+
default=None, metadata={"max_topk": C.MAX_TOPK_SECONDARY_STRUCTURE}
|
319 |
+
)
|
320 |
+
sasa: SamplingTrackConfig | None = attr.field(
|
321 |
+
default=None, metadata={"max_topk": C.MAX_TOPK_SASA}
|
322 |
+
)
|
323 |
+
function: SamplingTrackConfig | None = attr.field(
|
324 |
+
default=None, metadata={"max_topk": C.MAX_TOPK_FUNCTION}
|
325 |
+
)
|
326 |
+
|
327 |
+
return_per_residue_embeddings: bool = False
|
328 |
+
return_mean_embedding: bool = False
|
329 |
+
|
330 |
+
|
331 |
+
@define
|
332 |
+
class ForwardTrackData:
|
333 |
+
sequence: torch.Tensor | None = None
|
334 |
+
structure: torch.Tensor | None = None
|
335 |
+
secondary_structure: torch.Tensor | None = None
|
336 |
+
sasa: torch.Tensor | None = None
|
337 |
+
function: torch.Tensor | None = None
|
338 |
+
|
339 |
+
|
340 |
+
@define
|
341 |
+
class LogitsConfig:
|
342 |
+
# Logits.
|
343 |
+
sequence: bool = False
|
344 |
+
structure: bool = False
|
345 |
+
secondary_structure: bool = False
|
346 |
+
sasa: bool = False
|
347 |
+
function: bool = False
|
348 |
+
residue_annotations: bool = False
|
349 |
+
|
350 |
+
# Embeddings.
|
351 |
+
return_embeddings: bool = False
|
352 |
+
|
353 |
+
|
354 |
+
@define
|
355 |
+
class LogitsOutput:
|
356 |
+
logits: ForwardTrackData | None = None
|
357 |
+
embeddings: torch.Tensor | None = None
|
358 |
+
|
359 |
+
# Residue annotations is multi-hot, so deserves special treatment
|
360 |
+
# It's not a categorical distribution, but instead a bernoulli, so
|
361 |
+
# softmax across the last dimension is _wrong_
|
362 |
+
residue_annotation_logits: torch.Tensor | None = None
|
363 |
+
|
364 |
+
|
365 |
+
@define
|
366 |
+
class ForwardAndSampleOutput(LogitsOutput):
|
367 |
+
protein_tensor: ESMProteinTensor = ESMProteinTensor()
|
368 |
+
|
369 |
+
entropy: ForwardTrackData | None = None
|
370 |
+
# Probability of sampled token
|
371 |
+
prob: ForwardTrackData | None = None
|
372 |
+
logprob: ForwardTrackData | None = None
|
373 |
+
# Top probability at this position
|
374 |
+
top_prob: ForwardTrackData | None = None
|
375 |
+
topk_logprob: ForwardTrackData | None = None
|
376 |
+
# Which tokens correspond to top probability
|
377 |
+
topk_tokens: ForwardTrackData | None = None
|
378 |
+
per_residue_embedding: torch.Tensor | None = None
|
379 |
+
mean_embedding: torch.Tensor | None = None
|
380 |
+
|
381 |
+
|
382 |
+
class ESM3InferenceClient(ABC):
|
383 |
+
def generate(self, input: ProteinType, config: GenerationConfig) -> ProteinType:
|
384 |
+
# This is the easiest and most flexible way to run ESM3. Generate will
|
385 |
+
# iteratively sample tokens an provide an output with the track specified
|
386 |
+
# completely filled out, according to the GenerationConfig provided.
|
387 |
+
# It is a local function wrapping calls for encode -> iterative_sampling -> decode.
|
388 |
+
# if a ESMProteinTensor is provided, encode and decode are skipped
|
389 |
+
raise NotImplementedError
|
390 |
+
|
391 |
+
def batch_generate(
|
392 |
+
self, inputs: Sequence[ProteinType], configs: Sequence[GenerationConfig]
|
393 |
+
) -> Sequence[ProteinType]:
|
394 |
+
# Same as generate(...), but generates a batch of proteins at once.
|
395 |
+
raise NotImplementedError
|
396 |
+
|
397 |
+
def encode(self, input: ESMProtein) -> ESMProteinTensor:
|
398 |
+
# Encode allows for encoding RawRepresentation into TokenizedRepresentation.
|
399 |
+
# This runs the structure_token_encoder, as well as dealing with PDB => atom37 conversion
|
400 |
+
raise NotImplementedError
|
401 |
+
|
402 |
+
def decode(self, input: ESMProteinTensor) -> ESMProtein:
|
403 |
+
# Decode is the inverse of encode, and runs a structure_token_decoder to output coordinates
|
404 |
+
raise NotImplementedError
|
405 |
+
|
406 |
+
def logits(
|
407 |
+
self, input: ESMProteinTensor, config: LogitsConfig = LogitsConfig()
|
408 |
+
) -> LogitsOutput:
|
409 |
+
# Our API generally discourages using raw forwards.
|
410 |
+
# This is because sending logits can be prohibitively expensive.
|
411 |
+
# Please use forward_and_sample instead.
|
412 |
+
raise NotImplementedError
|
413 |
+
|
414 |
+
def forward_and_sample(
|
415 |
+
self, input: ESMProteinTensor, sampling_configuration: SamplingConfig
|
416 |
+
) -> ForwardAndSampleOutput:
|
417 |
+
# forward_and_sample runs a single model forward, sampling tokens according to `SamplingConfiguration`.
|
418 |
+
# This is the way for power users to run ESM3. We hope to design this in a way to enable high throughput
|
419 |
+
# inference, as well as arbitrary chain-of-though invocations of ESM3.
|
420 |
+
raise NotImplementedError
|
421 |
+
|
422 |
+
@property
|
423 |
+
def raw_model(self):
|
424 |
+
# Get underlying esm3 model of an inference client.
|
425 |
+
raise NotImplementedError
|
426 |
+
|
427 |
+
|
428 |
+
class ESMCInferenceClient(ABC):
|
429 |
+
def encode(self, input: ESMProtein) -> ESMProteinTensor:
|
430 |
+
# Encode allows for encoding RawRepresentation into TokenizedRepresentation.
|
431 |
+
raise NotImplementedError
|
432 |
+
|
433 |
+
def decode(self, input: ESMProteinTensor) -> ESMProtein:
|
434 |
+
# Decode is the inverse of encode
|
435 |
+
raise NotImplementedError
|
436 |
+
|
437 |
+
def logits(
|
438 |
+
self, input: ESMProteinTensor, config: LogitsConfig = LogitsConfig()
|
439 |
+
) -> LogitsOutput:
|
440 |
+
raise NotImplementedError
|
441 |
+
|
442 |
+
@property
|
443 |
+
def raw_model(self):
|
444 |
+
# Get underlying esmc model of an inference client.
|
445 |
+
raise NotImplementedError
|
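The comments above describe generate() as a wrapper around encode -> iterative sampling -> decode. A minimal usage sketch, assuming some concrete ESM3InferenceClient implementation is already bound to `client` and that "_" marks masked positions (both assumptions, not defined in this file):

from esm.sdk.api import ESMProtein, GenerationConfig

# Ask the model to fill in 20 masked sequence positions.
protein = ESMProtein(sequence="MKTAYIAKQR" + "_" * 20)  # "_" assumed to be the mask character
config = GenerationConfig(track="sequence", num_steps=8, temperature=0.7)

generated = client.generate(protein, config)  # encode -> iterative sampling -> decode
print(generated.sequence)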
Dyna-1/esm/sdk/forge.py
ADDED
@@ -0,0 +1,580 @@
import asyncio
from functools import wraps
from typing import Sequence
from urllib.parse import urljoin

import requests
import torch
from tenacity import retry, retry_if_result, stop_after_attempt, wait_exponential

from esm.sdk.api import (
    ESM3InferenceClient,
    ESMProtein,
    ESMProteinError,
    ESMProteinTensor,
    ForwardAndSampleOutput,
    ForwardTrackData,
    GenerationConfig,
    InverseFoldingConfig,
    LogitsConfig,
    LogitsOutput,
    ProteinType,
    SamplingConfig,
    SamplingTrackConfig,
)
from esm.utils.misc import maybe_list, maybe_tensor
from esm.utils.sampling import validate_sampling_config
from esm.utils.types import FunctionAnnotation


def _list_to_function_annotations(l) -> list[FunctionAnnotation] | None:
    if l is None or len(l) <= 0:
        return None
    return [FunctionAnnotation(*t) for t in l]


def retry_if_specific_error(exception):
    """
    We only retry on specific errors.
    Currently we retry for 429 (rate limit), 502 (bad gateway) and 504 (gateway timeout).
    """
    return isinstance(exception, ESMProteinError) and exception.error_code in {
        429,
        502,
        504,
    }


def log_retry_attempt(retry_state):
    print(
        f"Retrying... Attempt {retry_state.attempt_number} after {retry_state.next_action.sleep}s due to: {retry_state.outcome.result()}"
    )


def _validate_protein_tensor_input(input):
    if not isinstance(input, ESMProteinTensor):
        raise ValueError(
            "Input must be an ESMProteinTensor instance. "
            "Use encode() API to encode an ESMProtein into ESMProteinTensor."
        )


class SequenceStructureForgeInferenceClient:
    def __init__(
        self,
        url: str = "https://forge.evolutionaryscale.ai",
        token: str = "",
        request_timeout: int | None = None,
    ):
        if token == "":
            raise RuntimeError(
                "Please provide a token to connect to Forge via token=YOUR_API_TOKEN_HERE"
            )
        self.url = url
        self.token = token
        self.headers = {"Authorization": f"Bearer {self.token}"}
        self.request_timeout = request_timeout

    def fold(
        self,
        sequence: str,
        potential_sequence_of_concern: bool,
        model_name: str | None = None,
    ) -> ESMProtein | ESMProteinError:
        request = {"sequence": sequence}
        if model_name is not None:
            request["model"] = model_name
        try:
            data = self._post("fold", request, potential_sequence_of_concern)
        except ESMProteinError as e:
            return e

        return ESMProtein(
            coordinates=maybe_tensor(data["coordinates"], convert_none_to_nan=True)
        )

    def inverse_fold(
        self,
        coordinates: torch.Tensor,
        config: InverseFoldingConfig,
        potential_sequence_of_concern: bool,
        model_name: str | None = None,
    ) -> ESMProtein | ESMProteinError:
        inverse_folding_config = {
            "invalid_ids": config.invalid_ids,
            "temperature": config.temperature,
        }
        request = {
            "coordinates": maybe_list(coordinates, convert_nan_to_none=True),
            "inverse_folding_config": inverse_folding_config,
        }
        if model_name is not None:
            request["model"] = model_name
        try:
            data = self._post("inverse_fold", request, potential_sequence_of_concern)
        except ESMProteinError as e:
            return e

        return ESMProtein(sequence=data["sequence"])

    def _post(self, endpoint, request, potential_sequence_of_concern):
        request["potential_sequence_of_concern"] = potential_sequence_of_concern

        response = requests.post(
            urljoin(self.url, f"/api/v1/{endpoint}"),
            json=request,
            headers=self.headers,
            timeout=self.request_timeout,
        )

        if not response.ok:
            raise ESMProteinError(
                error_code=response.status_code,
                error_msg=f"Failure in {endpoint}: {response.text}",
            )

        data = response.json()
        # Nextjs puts outputs dict under "data" key.
        # Lift it up for easier downstream processing.
        if "outputs" not in data and "data" in data:
            data = data["data"]

        # Print warning message if there is any.
        if "warning_messages" in data and data["warning_messages"] is not None:
            for msg in data["warning_messages"]:
                print("\033[31m", msg, "\033[0m")

        return data


class ESM3ForgeInferenceClient(ESM3InferenceClient):
    def __init__(
        self,
        model: str,
        url: str = "https://forge.evolutionaryscale.ai",
        token: str = "",
        request_timeout: int | None = None,
        min_retry_wait: int = 1,
        max_retry_wait: int = 10,
        max_retry_attempts: int = 5,
    ):
        if token == "":
            raise RuntimeError(
                "Please provide a token to connect to Forge via token=YOUR_API_TOKEN_HERE"
            )
        self.model = model  # Name of the model to run.
        self.url = url
        self.token = token
        self.headers = {"Authorization": f"Bearer {self.token}"}
        self.request_timeout = request_timeout
        self.min_retry_wait = min_retry_wait
        self.max_retry_wait = max_retry_wait
        self.max_retry_attempts = max_retry_attempts

    @staticmethod
    def retry_decorator(func):
        """
        A static method that returns a retry decorator. This decorator uses the
        instance's retry settings.
        """

        @wraps(func)
        def wrapper(instance, *args, **kwargs):
            retry_decorator = retry(
                retry=retry_if_result(retry_if_specific_error),
                wait=wait_exponential(
                    multiplier=1,
                    min=instance.min_retry_wait,
                    max=instance.max_retry_wait,
                ),
                stop=stop_after_attempt(instance.max_retry_attempts),
                before_sleep=log_retry_attempt,
            )
            # Apply the retry decorator to the function
            return retry_decorator(func)(instance, *args, **kwargs)

        return wrapper

    @retry_decorator
    def generate(self, input: ProteinType, config: GenerationConfig) -> ProteinType:
        if isinstance(input, ESMProtein):
            output = self.__generate_protein(input, config)
        elif isinstance(input, ESMProteinTensor):
            output = self.__generate_protein_tensor(input, config)
        else:
            return ESMProteinError(
                error_code=500, error_msg=f"Unknown input type {type(input)}"
            )

        if (
            isinstance(output, ESMProtein)
            and isinstance(input, ESMProtein)
            and config.track not in ["function", "residue_annotations"]
        ):
            # Function and residue annotation encoding/decoding is lossy.
            # There is no guarantee that decoding encoded tokens will yield the same input.
            output.function_annotations = input.function_annotations

        return output

    def batch_generate(
        self, inputs: Sequence[ProteinType], configs: Sequence[GenerationConfig]
    ) -> Sequence[ProteinType]:
        """Forge supports auto-batching. So batch_generate() for the Forge client
        is as simple as running a collection of generate() calls in parallel using asyncio.
        """
        loop = asyncio.get_event_loop()

        async def _async_generate():
            futures = [
                loop.run_in_executor(None, self.generate, protein, config)
                for protein, config in zip(inputs, configs)
            ]
            return await asyncio.gather(*futures, return_exceptions=True)

        results = loop.run_until_complete(_async_generate())

        def _capture_exception(r):
            if isinstance(r, BaseException) and not isinstance(r, ESMProteinError):
                return ESMProteinError(500, str(r))
            return r

        return [_capture_exception(r) for r in results]

    def __generate_protein(
        self, input: ESMProtein, config: GenerationConfig
    ) -> ESMProtein | ESMProteinError:
        req = {}
        req["sequence"] = input.sequence
        req["secondary_structure"] = input.secondary_structure
        req["sasa"] = input.sasa
        if input.function_annotations is not None:
            req["function"] = [x.to_tuple() for x in input.function_annotations]
        req["coordinates"] = maybe_list(input.coordinates, convert_nan_to_none=True)

        request = {
            "model": self.model,
            "inputs": req,
            "track": config.track,
            "invalid_ids": config.invalid_ids,
            "schedule": config.schedule,
            "num_steps": config.num_steps,
            "temperature": config.temperature,
            "top_p": config.top_p,
            "condition_on_coordinates_only": config.condition_on_coordinates_only,
        }
        try:
            data = self._post("generate", request, input.potential_sequence_of_concern)
        except ESMProteinError as e:
            return e

        return ESMProtein(
            sequence=data["outputs"]["sequence"],
            secondary_structure=data["outputs"]["secondary_structure"],
            sasa=data["outputs"]["sasa"],
            function_annotations=_list_to_function_annotations(
                data["outputs"]["function"]
            ),
            coordinates=maybe_tensor(
                data["outputs"]["coordinates"], convert_none_to_nan=True
            ),
            plddt=maybe_tensor(data["outputs"]["plddt"]),
            ptm=maybe_tensor(data["outputs"]["ptm"]),
        )

    def __generate_protein_tensor(
        self, input: ESMProteinTensor, config: GenerationConfig
    ) -> ESMProteinTensor | ESMProteinError:
        req = {}
        req["sequence"] = maybe_list(input.sequence)
        req["structure"] = maybe_list(input.structure)
        req["secondary_structure"] = maybe_list(input.secondary_structure)
        req["sasa"] = maybe_list(input.sasa)
        req["function"] = maybe_list(input.function)
        req["coordinates"] = maybe_list(input.coordinates, convert_nan_to_none=True)
        req["residue_annotation"] = maybe_list(input.residue_annotations)

        request = {
            "model": self.model,
            "inputs": req,
            "track": config.track,
            "invalid_ids": config.invalid_ids,
            "schedule": config.schedule,
            "num_steps": config.num_steps,
            "temperature": config.temperature,
            "top_p": config.top_p,
            "condition_on_coordinates_only": config.condition_on_coordinates_only,
        }

        try:
            data = self._post(
                "generate_tensor", request, input.potential_sequence_of_concern
            )
        except ESMProteinError as e:
            return e

        def _field_to_tensor(field, convert_none_to_nan: bool = False):
            if field not in data["outputs"]:
                return None
            return maybe_tensor(
                data["outputs"][field], convert_none_to_nan=convert_none_to_nan
            )

        output = ESMProteinTensor(
            sequence=_field_to_tensor("sequence"),
            structure=_field_to_tensor("structure"),
            secondary_structure=_field_to_tensor("secondary_structure"),
            sasa=_field_to_tensor("sasa"),
            function=_field_to_tensor("function"),
            residue_annotations=_field_to_tensor("residue_annotation"),
            coordinates=_field_to_tensor("coordinates", convert_none_to_nan=True),
        )

        return output

    @retry_decorator
    def forward_and_sample(
        self, input: ESMProteinTensor, sampling_configuration: SamplingConfig
    ) -> ForwardAndSampleOutput | ESMProteinError:
        _validate_protein_tensor_input(input)
        validate_sampling_config(sampling_configuration, on_invalid="raise")

        req = {}
        sampling_config = {}
        embedding_config = {
            "sequence": sampling_configuration.return_mean_embedding,
            "per_residue": sampling_configuration.return_per_residue_embeddings,
        }

        req["sequence"] = maybe_list(input.sequence)
        req["structure"] = maybe_list(input.structure)
        req["secondary_structure"] = maybe_list(input.secondary_structure)
        req["sasa"] = maybe_list(input.sasa)
        req["function"] = maybe_list(input.function)
        req["coordinates"] = maybe_list(input.coordinates, convert_nan_to_none=True)
        req["residue_annotation"] = maybe_list(input.residue_annotations)

        def do_track(t: str):
            track: SamplingTrackConfig | None
            if (track := getattr(sampling_configuration, t, None)) is None:
                sampling_config[t] = None
            else:
                sampling_config[t] = {
                    "temperature": track.temperature,
                    "top_p": track.top_p,
                    "only_sample_masked_tokens": track.only_sample_masked_tokens,
                    "invalid_ids": track.invalid_ids,
                    "topk_logprobs": track.topk_logprobs,
                }

        do_track("sequence")
        do_track("structure")
        do_track("secondary_structure")
        do_track("sasa")
        do_track("function")

        request = {
            "model": self.model,
            "inputs": req,
            "sampling_config": sampling_config,
            "embedding_config": embedding_config,
        }
        try:
            data = self._post(
                "forward_and_sample", request, input.potential_sequence_of_concern
            )
        except ESMProteinError as e:
            return e

        def get(k, field):
            if data[k] is None:
                return None
            v = data[k][field]
            return torch.tensor(v) if v is not None else None

        tokens = ESMProteinTensor(
            sequence=get("sequence", "tokens"),
            structure=get("structure", "tokens"),
            secondary_structure=get("secondary_structure", "tokens"),
            sasa=get("sasa", "tokens"),
            function=get("function", "tokens"),
        )

        def get_track(field):
            return ForwardTrackData(
                sequence=get("sequence", field),
                structure=get("structure", field),
                secondary_structure=get("secondary_structure", field),
                sasa=get("sasa", field),
                function=get("function", field),
            )

        def operate_on_track(track: ForwardTrackData, fn):
            apply = lambda x: fn(x) if x is not None else None
            return ForwardTrackData(
                sequence=apply(track.sequence),
                structure=apply(track.structure),
                secondary_structure=apply(track.secondary_structure),
                sasa=apply(track.sasa),
                function=apply(track.function),
            )

        logprob = get_track("logprobs")
        output = ForwardAndSampleOutput(
            protein_tensor=tokens,
            logprob=logprob,
            prob=operate_on_track(logprob, torch.exp),
            entropy=get_track("entropy"),
            topk_logprob=get_track("topk_logprobs"),
            topk_tokens=get_track("topk_tokens"),
            per_residue_embedding=data["embeddings"]["per_residue"],
            mean_embedding=data["embeddings"]["sequence"],
        )
        return output

    @retry_decorator
    def encode(self, input: ESMProtein) -> ESMProteinTensor | ESMProteinError:
        tracks = {}
        tracks["sequence"] = input.sequence
        tracks["secondary_structure"] = input.secondary_structure
        tracks["sasa"] = input.sasa
        if input.function_annotations is not None:
            tracks["function"] = [x.to_tuple() for x in input.function_annotations]
        tracks["coordinates"] = maybe_list(input.coordinates, convert_nan_to_none=True)

        request = {"inputs": tracks, "model": self.model}

        try:
            data = self._post("encode", request, input.potential_sequence_of_concern)
        except ESMProteinError as e:
            return e

        return ESMProteinTensor(
            sequence=maybe_tensor(data["outputs"]["sequence"]),
            structure=maybe_tensor(data["outputs"]["structure"]),
            coordinates=maybe_tensor(
                data["outputs"]["coordinates"], convert_none_to_nan=True
            ),
            secondary_structure=maybe_tensor(data["outputs"]["secondary_structure"]),
            sasa=maybe_tensor(data["outputs"]["sasa"]),
            function=maybe_tensor(data["outputs"]["function"]),
            residue_annotations=maybe_tensor(data["outputs"]["residue_annotation"]),
        )

    @retry_decorator
    def decode(self, input: ESMProteinTensor) -> ESMProtein | ESMProteinError:
        _validate_protein_tensor_input(input)

        tokens = {}
        tokens["sequence"] = maybe_list(input.sequence)
        tokens["structure"] = maybe_list(input.structure)
        tokens["secondary_structure"] = maybe_list(input.secondary_structure)
        tokens["sasa"] = maybe_list(input.sasa)
        tokens["function"] = maybe_list(input.function)
        tokens["residue_annotation"] = maybe_list(input.residue_annotations)
        tokens["coordinates"] = maybe_list(input.coordinates, convert_nan_to_none=True)

        request = {"model": self.model, "inputs": tokens}

        try:
            data = self._post("decode", request, input.potential_sequence_of_concern)
        except ESMProteinError as e:
            return e

        return ESMProtein(
            sequence=data["outputs"]["sequence"],
            secondary_structure=data["outputs"]["secondary_structure"],
            sasa=data["outputs"]["sasa"],
            function_annotations=_list_to_function_annotations(
                data["outputs"]["function"]
            ),
            coordinates=maybe_tensor(
                data["outputs"]["coordinates"], convert_none_to_nan=True
            ),
            plddt=maybe_tensor(data["outputs"]["plddt"]),
            ptm=maybe_tensor(data["outputs"]["ptm"]),
        )

    @retry_decorator
    def logits(
        self, input: ESMProteinTensor, config: LogitsConfig = LogitsConfig()
    ) -> LogitsOutput | ESMProteinError:
        _validate_protein_tensor_input(input)

        # Note: using raw model forwards is discouraged because of the byte size
        # of the logits.
        # Please use forward_and_sample instead.
        req = {}
        req["sequence"] = maybe_list(input.sequence)
        req["structure"] = maybe_list(input.structure)
        req["secondary_structure"] = maybe_list(input.secondary_structure)
        req["sasa"] = maybe_list(input.sasa)
        req["function"] = maybe_list(input.function)
        req["coordinates"] = maybe_list(input.coordinates, convert_nan_to_none=True)
        req["residue_annotation"] = maybe_list(input.residue_annotations)

        logits_config = {
            "sequence": config.sequence,
            "structure": config.structure,
            "secondary_structure": config.secondary_structure,
            "sasa": config.sasa,
            "function": config.function,
            "residue_annotations": config.residue_annotations,
            "return_embeddings": config.return_embeddings,
        }

        request = {"model": self.model, "inputs": req, "logits_config": logits_config}

        try:
            data = self._post("logits", request, input.potential_sequence_of_concern)
        except ESMProteinError as e:
            return e

        def _maybe_logits(track: str):
            if "logits" in data and track in data["logits"]:
                return maybe_tensor(data["logits"][track])
            return None

        output = LogitsOutput(
            logits=ForwardTrackData(
                sequence=_maybe_logits("sequence"),
                structure=_maybe_logits("structure"),
                secondary_structure=_maybe_logits("secondary_structure"),
                sasa=_maybe_logits("sasa"),
                function=_maybe_logits("function"),
            ),
            embeddings=maybe_tensor(data["embeddings"]),
            residue_annotation_logits=_maybe_logits("residue_annotation"),
        )

        return output

    def _post(self, endpoint, request, potential_sequence_of_concern):
        request["potential_sequence_of_concern"] = potential_sequence_of_concern

        response = requests.post(
            urljoin(self.url, f"/api/v1/{endpoint}"),
            json=request,
            headers=self.headers,
            timeout=self.request_timeout,
        )

        if not response.ok:
            raise ESMProteinError(
                error_code=response.status_code,
                error_msg=f"Failure in {endpoint}: {response.text}",
            )

        data = response.json()
        # Nextjs puts outputs dict under "data" key.
        # Lift it up for easier downstream processing.
        if "outputs" not in data and "data" in data:
            data = data["data"]

        return data

    @property
    def raw_model(self):
        raise NotImplementedError(
            f"Can not get underlying remote model {self.model} from a Forge client."
        )
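A hypothetical end-to-end call through the Forge client defined above; the model name and the ESM_API_KEY environment variable are placeholders, not values taken from this repository:

import os

from esm.sdk.api import ESMProtein, ESMProteinError, GenerationConfig
from esm.sdk.forge import ESM3ForgeInferenceClient

client = ESM3ForgeInferenceClient(
    model="<forge-model-name>",               # placeholder model name
    token=os.environ.get("ESM_API_KEY", ""),  # placeholder token source
)
result = client.generate(
    ESMProtein(sequence="MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"),
    GenerationConfig(track="structure", num_steps=8),
)
# Failed calls come back as ESMProteinError values rather than raising.
if isinstance(result, ESMProteinError):
    print(result.error_code, result.error_msg)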
Dyna-1/esm/sdk/sagemaker.py
ADDED
@@ -0,0 +1,110 @@
import json

import boto3

from esm.sdk.forge import (
    ESM3ForgeInferenceClient,
    SequenceStructureForgeInferenceClient,
)


class SequenceStructureSageMakerClient(SequenceStructureForgeInferenceClient):
    def __init__(self, endpoint_name: str):
        """SequenceStructure (folding and inverse folding) client that talks to a SageMaker endpoint.

        Args:
            endpoint_name: Name of the SageMaker endpoint.
        """
        # Dummy URL and token to make SequenceStructureForgeInferenceClient happy.
        super().__init__(url="", token="dummy")

        self._endpoint_name = endpoint_name

        self._client = boto3.client(service_name="sagemaker-runtime")

    def _post(self, endpoint, request, potential_sequence_of_concern):
        request["potential_sequence_of_concern"] = potential_sequence_of_concern
        request["model"] = request.get("model", None)
        invocations_request = {
            # Duplicate these fields at the top level to make Forge requests consistent.
            "model": request["model"],
            "request_id": "",  # Forge specific field.
            "user_id": "",  # Forge specific field.
            # Invocation data bits.
            "api_ver": "v1",  # Must be v1 right now.
            "endpoint": endpoint,
            # Wrapped request.
            endpoint: request,
        }

        try:
            response = self._client.invoke_endpoint(
                EndpointName=self._endpoint_name,
                ContentType="application/json",
                Body=json.dumps(invocations_request),
            )
        except Exception as e:
            raise RuntimeError(f"Failure in {endpoint}: {e}") from e

        data = json.loads(response["Body"].read().decode())

        # Response must match request.
        assert (
            data["endpoint"] == endpoint
        ), f"Response endpoint is {data['endpoint']} but request is {endpoint}"

        # Get the actual responses under the endpoint key.
        data = data[endpoint]

        return data


class ESM3SageMakerClient(ESM3ForgeInferenceClient):
    def __init__(self, endpoint_name: str, model: str):
        """ESM3 client that talks to a SageMaker endpoint.

        Args:
            endpoint_name: Name of the SageMaker endpoint.
            model: Name of the ESM3 model.
        """
        # Dummy URL and token to make ESM3ForgeInferenceClient happy.
        super().__init__(model=model, url="", token="dummy")

        self._endpoint_name = endpoint_name
        self._model = model

        self._client = boto3.client(service_name="sagemaker-runtime")

    def _post(self, endpoint, request, potential_sequence_of_concern):
        request["potential_sequence_of_concern"] = potential_sequence_of_concern

        invocations_request = {
            # Duplicate these fields at the top level to make Forge requests consistent.
            "model": request["model"],
            "request_id": "",  # Forge specific field.
            "user_id": "",  # Forge specific field.
            # Invocation data bits.
            "api_ver": "v1",  # Must be v1 right now.
            "endpoint": endpoint,
            # Wrapped request.
            endpoint: request,
        }

        try:
            response = self._client.invoke_endpoint(
                EndpointName=self._endpoint_name,
                ContentType="application/json",
                Body=json.dumps(invocations_request),
            )
        except Exception as e:
            raise RuntimeError(f"Failure in {endpoint}: {e}")

        data = json.loads(response["Body"].read().decode())

        # Response must match request.
        assert data["endpoint"] == endpoint

        # Get the actual responses under the endpoint key.
        data = data[endpoint]

        return data
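Both SageMaker clients reuse the Forge request schema and only swap the transport: each request is wrapped in an invocation payload keyed by the endpoint name and sent via boto3. A hypothetical wiring, with placeholder endpoint and model names:

from esm.sdk.api import ESMProtein, GenerationConfig
from esm.sdk.sagemaker import ESM3SageMakerClient

# Endpoint and model names below are placeholders for an existing deployment.
client = ESM3SageMakerClient(endpoint_name="my-esm3-endpoint", model="<model-name>")
result = client.generate(
    ESMProtein(sequence="MKTAYIAKQR"),
    GenerationConfig(track="sequence", num_steps=4),
)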
Dyna-1/esm/tokenization/__init__.py
ADDED
@@ -0,0 +1,69 @@
from dataclasses import dataclass
from typing import Protocol

from esm.utils.constants.models import (
    ESM3_OPEN_SMALL,
    normalize_model_name,
)

from .function_tokenizer import InterProQuantizedTokenizer
from .residue_tokenizer import ResidueAnnotationsTokenizer
from .sasa_tokenizer import SASADiscretizingTokenizer
from .sequence_tokenizer import EsmSequenceTokenizer
from .ss_tokenizer import SecondaryStructureTokenizer
from .structure_tokenizer import StructureTokenizer
from .tokenizer_base import EsmTokenizerBase


class TokenizerCollectionProtocol(Protocol):
    sequence: EsmSequenceTokenizer
    structure: StructureTokenizer
    secondary_structure: SecondaryStructureTokenizer
    sasa: SASADiscretizingTokenizer
    function: InterProQuantizedTokenizer
    residue_annotations: ResidueAnnotationsTokenizer


@dataclass
class TokenizerCollection:
    sequence: EsmSequenceTokenizer
    structure: StructureTokenizer
    secondary_structure: SecondaryStructureTokenizer
    sasa: SASADiscretizingTokenizer
    function: InterProQuantizedTokenizer
    residue_annotations: ResidueAnnotationsTokenizer


def get_esm3_model_tokenizers(model: str = ESM3_OPEN_SMALL) -> TokenizerCollection:
    if normalize_model_name(model) == ESM3_OPEN_SMALL:
        return TokenizerCollection(
            sequence=EsmSequenceTokenizer(),
            structure=StructureTokenizer(),
            secondary_structure=SecondaryStructureTokenizer(kind="ss8"),
            sasa=SASADiscretizingTokenizer(),
            function=InterProQuantizedTokenizer(),
            residue_annotations=ResidueAnnotationsTokenizer(),
        )
    else:
        raise ValueError(f"Unknown model: {model}")


def get_esmc_model_tokenizers() -> EsmSequenceTokenizer:
    return EsmSequenceTokenizer()


def get_invalid_tokenizer_ids(tokenizer: EsmTokenizerBase) -> list[int]:
    if isinstance(tokenizer, EsmSequenceTokenizer):
        return [
            tokenizer.mask_token_id,  # type: ignore
            tokenizer.pad_token_id,  # type: ignore
            tokenizer.cls_token_id,  # type: ignore
            tokenizer.eos_token_id,  # type: ignore
        ]
    else:
        return [
            tokenizer.mask_token_id,
            tokenizer.pad_token_id,
            tokenizer.bos_token_id,
            tokenizer.eos_token_id,
        ]
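A short sketch of how this collection is typically used downstream, for example to collect the token ids that sampling should never emit; all names come from this module:

from esm.tokenization import get_esm3_model_tokenizers, get_invalid_tokenizer_ids

tokenizers = get_esm3_model_tokenizers()  # defaults to ESM3_OPEN_SMALL
invalid_ids = get_invalid_tokenizer_ids(tokenizers.sequence)
print(invalid_ids)  # mask/pad/cls/eos ids for the sequence track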
Dyna-1/esm/tokenization/function_tokenizer.py
ADDED
@@ -0,0 +1,429 @@
"""Tokenizes annotations of protein function."""

import re
import string
from functools import cache, cached_property, partial
from typing import Collection

import numpy as np
import pandas as pd
import scipy.sparse as sp
import torch
import torch.nn.functional as F

from esm.tokenization.tokenizer_base import EsmTokenizerBase
from esm.utils.constants import esm3 as C
from esm.utils.function import interpro, lsh, tfidf
from esm.utils.misc import stack_variable_length_tensors
from esm.utils.types import FunctionAnnotation, PathLike


def _default_data_path(x: PathLike | None, d: PathLike) -> PathLike:
    return x if x is not None else C.data_root("esm3") / d


def _default_local_data_path(x: PathLike | None, d: PathLike) -> PathLike:
    return x if x is not None else d


class InterProQuantizedTokenizer(EsmTokenizerBase):
    """Tokenizer for functional annotations.

    This tokenizer converts InterPro and/or function keywords into a multi-token
    representation by hashing TF-IDF vector representations of the text associated with
    the function and then applying a locality sensitive hash (LSH).
    """

    def __init__(
        self,
        depth: int = 8,
        lsh_bits_per_token: int = 8,
        lsh_path: PathLike | None = None,
        keyword_vocabulary_path: PathLike | None = None,
        keyword_idf_path: PathLike | None = None,
        interpro_entry_path: PathLike | None = None,
        interpro2keywords_path: PathLike | None = None,
    ):
        """Constructs function tokenizer.

        Args:
            depth: number of tokens emitted in each position.
            lsh_bits_per_token: Number of LSH bits per token. Determines the vocabulary
                size.
            lsh_path: path to locality sensitive hash (LSH) hyperplanes.
            keyword_vocabulary_path: path to csv containing function keyword vocabulary.
            keyword_idf_path: path to IDF values for each keyword.
            interpro_entry_path: path to list of InterPro entries in CSV format.
            interpro2keywords_path: path to CSV mapping InterPro IDs to function keywords.
        """
        self.depth = depth

        self.keyword_vocabulary_path = _default_local_data_path(
            keyword_vocabulary_path, C.KEYWORDS_VOCABULARY
        )
        self.keyword_idf_path = _default_local_data_path(
            keyword_idf_path, C.KEYWORDS_IDF
        )

        self._interpro2keywords_path = _default_local_data_path(
            interpro2keywords_path, C.INTERPRO2KEYWORDS
        )
        self.interpro_ = interpro.InterPro(
            entries_path=_default_local_data_path(interpro_entry_path, C.INTERPRO_ENTRY)
        )

        self.lsh_path = lsh_path
        self.lsh_bits_per_token = lsh_bits_per_token
        self.lsh_vocab_size = 1 << lsh_bits_per_token

        # This is the offset into the vocabulary where LSH tokens start.
        self._lsh_token_vocab_offset = len(self.special_tokens) + 1  # +1 for <none>

    @cached_property
    def _lsh(self) -> lsh.LSHTokenized:
        """Locality sensitive hash for function annotations."""
        return lsh.LSHTokenized(
            self.lsh_bits_per_token,
            len(self.keyword_vocabulary),
            self.depth,
            _default_data_path(self.lsh_path, C.LSH_TABLE_PATHS["8bit"]),
        )

    @cached_property
    def interpro2keywords(self) -> dict[str, list[str]]:
        """Mapping from InterPro ID to function keywords."""
        df = pd.read_csv(self._interpro2keywords_path)
        assert "interpro_id" in df.columns and "keywords" in df.columns, df.columns
        return dict(zip(df.interpro_id, df.keywords.str.split(",")))

    @cached_property
    def interpro_labels(self) -> list[str]:
        """The set of supported InterPro labels."""
        return sorted(self.interpro2keywords.keys())

    @cached_property
    def interpro_to_index(self) -> dict[str, int]:
        """Mapping from InterPro id to index."""
        return {id: i for i, id in enumerate(self.interpro_labels)}

    @property
    def keyword_vocabulary(self) -> list[str]:
        """Set of supported keywords."""
        return self._tfidf.vocabulary

    @property
    def keyword_to_index(self) -> dict[str, int]:
        """Mapping from keywords to index."""
        return self._tfidf.vocab_to_index

    @cached_property
    def _tfidf(self) -> tfidf.TFIDFModel:
        """Creates TF-IDF model for encoding function keywords."""
        return tfidf.TFIDFModel(
            vocabulary_path=self.keyword_vocabulary_path, idf_path=self.keyword_idf_path
        )

    @cached_property
    def special_tokens(self) -> list[str]:
        """List of special tokens which come before cluster tokens in vocab."""
        return ["<pad>", "<motif>", "<unk>"]

    @cached_property
    def vocab(self) -> list[str]:
        """Vocabulary of function tokens."""
        lsh_tokens = [f"<lsh:{i}>" for i in range(self.lsh_vocab_size)]
        return self.special_tokens + ["<none>"] + lsh_tokens

    @cached_property
    def vocab_to_index(self) -> dict[str, int]:
        return {token: token_id for token_id, token in enumerate(self.vocab)}

    def get_special_tokens_mask(self, encoded: torch.Tensor) -> torch.Tensor:
        """Determines where in the sequence are special tokens."""
        where = encoded < len(self.special_tokens)
        assert torch.all(torch.all(where, dim=1) | torch.all(~where, dim=1))
        return where[:, 0]

    def tokenize(
        self,
        annotations: list[FunctionAnnotation],
        seqlen: int,
        p_keyword_dropout: float = 0.0,
    ) -> list[str]:
        """Encodes range-annotations of protein function as tokens.

        Args:
            annotations: Annotated function ranges, either as InterPro ids or keywords.
            seqlen: length of sequence.
            p_keyword_dropout: Optional probability of dropping out keywords from the
                input annotations.
        Returns:
            Tokenized representation of function annotations as a list of string tokens
            of size seqlen.
        """
        assert seqlen >= 0

        if not annotations:
            return ["<pad>"] * seqlen

        # Expand the range annotations into positional annotation sets.
        positional_labels: list[set[str]] = [set() for _ in range(seqlen)]
        for annotation in annotations:
            assert 1 <= annotation.start <= annotation.end <= seqlen, (
                f"Invalid annotation range [{annotation.start}, {annotation.end}] for "
                f"sequence length {seqlen}."
            )
            for i in range(annotation.start - 1, annotation.end):
                positional_labels[i].add(annotation.label)

        if p_keyword_dropout > 0:
            keyword_mask = (
                np.random.random(len(self._tfidf.vocabulary)) < p_keyword_dropout
            )
        else:
            keyword_mask = None

        # Annotations tend to be repetitive over the length of the sequence - cache their
        # hashes to speed up tokenization.
        hash_fn = cache(partial(self._function_text_hash, keyword_mask=keyword_mask))

        tokens: list[str] = []
        for labels in positional_labels:
            if not labels:
                token = "<none>"
            else:
                lsh_hash = hash_fn(frozenset(labels))
                if lsh_hash is not None:
                    assert len(lsh_hash) == self.depth
                    token = "<lsh:" + ",".join(map(str, lsh_hash)) + ">"
                else:
                    token = "<unk>"

            tokens.append(token)

        return tokens

    def _function_text_hash(
        self, labels: Collection[str], keyword_mask: np.ndarray | None = None
    ) -> np.ndarray | None:
        """Applies a locality sensitive hash (LSH) to function text.

        Args:
            labels: InterPro ids and/or keywords.
            keyword_mask: optional boolean array shaped (keyword_vocab_size,) indicating
                which keywords to drop before hashing.
        Returns:
            LSH shaped (depth,) or None if there is no text or keywords to hash.
        """
        # Split labels into either InterPro ids or keywords.
        interpro_ids = []
        keywords = []
        for label in labels:
            match = re.search(r"IPR\d+", label)
            if match and match.group() in self.interpro_to_index:
                interpro_ids.append(match.group())
            elif label in self._tfidf.vocab_to_index:
                keywords.append(label)
            else:
                raise ValueError(f"Unsupported: {label}")

        # Perform an element-wise maximum over TF-IDF vectors from distinct tags to
        # avoid tags getting "washed out" by eg. 4 very similar tags. Keywords are
        # incorporated as another TF-IDF vector.
        vec: sp.csr_matrix = self._tfidf.encode(keywords)
        for interpro_id in interpro_ids:
            interpro_keywords = self.interpro2keywords.get(interpro_id, [])
            vec_ = self._tfidf.encode(interpro_keywords)
            vec = vec.maximum(vec_)

        if keyword_mask is not None:
            vec.data *= 1 - np.take(keyword_mask, vec.indices)

        if vec.sum() == 0:
            return None

        return self._lsh(vec)[0, :]

    def encode(
        self, tokens: list[str], add_special_tokens: bool = True
    ) -> torch.Tensor:
        """Encodes string tokens as token-id tensor.

        Args:
            tokens: list of individual tokens, e.g. ["<none>", "<lsh:1,2,3,4>"]
            add_special_tokens: whether to add a single pad token at the start and end
                of the sequence to act as <cls> and <eos> tokens.
        Returns:
            <int>[length, depth] function tokens. Length will be +2 of input tokens
            length when add_special_tokens is True.
        """
        token_ids = torch.zeros(size=(len(tokens), self.depth), dtype=torch.int64)
        for i, token in enumerate(tokens):
            token_ids[i, :] = torch.tensor(self._token2ids(token))
        if add_special_tokens:
            token_ids = F.pad(
                token_ids, (0, 0, 1, 1), value=self.vocab_to_index["<pad>"]
            )
        return token_ids

    def lookup_annotation_name(self, annotation: FunctionAnnotation) -> str | None:
        return self.interpro_.lookup_name(annotation.label)

    def format_annotation(self, annotation: FunctionAnnotation) -> str:
        annotation_name = self.lookup_annotation_name(annotation)
        if annotation_name is not None:
            return f"{annotation_name} ({annotation.label})"
        else:
            return annotation.label

    def _token2ids(self, token: str) -> list[int]:
        """Converts token into token_id set of length depth."""
        if re.match(r"<lsh:[\d+,]+>", token):
            lsh_ids = [int(lsh_id) for lsh_id in re.findall(r"\d+", token)]
            assert (
                len(lsh_ids) == self.depth
            ), f"Expected token to have {self.depth} ids found {lsh_ids}"
            return [self._lsh_token_vocab_offset + lsh_id for lsh_id in lsh_ids]
        elif token == "<none>" or token in self.special_tokens:
            return [self.vocab_to_index[token]] * self.depth
        else:
            raise ValueError(f"Unknown token: {token}")

    def batch_encode(
        self, token_batch: list[list[str]], add_special_tokens: bool = True
    ) -> torch.Tensor:
        """Encodes batch of function tokens.

        Args:
            token_batch: batch of function tokens.
            add_special_tokens: whether to add special tokens.
        Returns:
            <int>[batch_size, max_length, depth] batch of encoded tokens.
        """
        encoded = [
            self.encode(tokens, add_special_tokens=add_special_tokens)
            for tokens in token_batch
        ]
        return stack_variable_length_tensors(
            encoded, constant_value=self.vocab_to_index["<pad>"]
        )

    def decode(self, encoded: torch.Tensor):
        raise NotImplementedError(
            "Function token decoding should be handled with "
            "util.decoding.decode_function_annotations"
        )

    @property
    def mask_token(self) -> str:
        return "<pad>"

    @property
    def mask_token_id(self) -> int:
        return self.vocab_to_index[self.mask_token]

    @property
    def bos_token(self) -> str:
        return "<pad>"

    @property
    def bos_token_id(self) -> int:
        return self.vocab_to_index[self.bos_token]

    @property
    def eos_token(self) -> str:
        return "<pad>"

    @property
    def eos_token_id(self) -> int:
        return self.vocab_to_index[self.eos_token]

    @property
    def pad_token(self) -> str:
        return "<pad>"

    @property
    def pad_token_id(self) -> int:
        return self.vocab_to_index[self.pad_token]

    @property
    def chain_break_token(self) -> str:
        return "<pad>"

    @property
    def chain_break_token_id(self) -> int:
        return self.vocab_to_index[self.chain_break_token]

    @property
    def all_token_ids(self):
        return list(range(len(self.vocab)))

    @property
    def special_token_ids(self):
        return [self.vocab_to_index[token] for token in self.special_tokens]


def _texts_to_keywords(texts: list[str]) -> list[str]:
    """Breaks InterPro/GO free-text description set into bag-of-n-grams for n={1,2}.

    Args:
        texts: collection of text descriptions, i.e. InterPro/GO names.
    Returns:
        Collection of terms/n-grams.
    """
    keywords = []
    for text in texts:
        keywords.extend(_keywords_from_text(text))
    return keywords


def _keywords_from_text(text: str) -> list[str]:
    """Splits text into unigrams and bigrams."""
    elements = text.split(", ")

    terms = []
    for element in elements:
        element = _sanitize(element)
        words = element.split()

        # Add 1-mers
        terms.extend(words)

        # Add 2-mers
        for i in range(len(words) - 1):
            bigram = words[i] + " " + words[i + 1]
            terms.append(bigram)

    return [term for term in terms if len(term) > 1 and term not in _EXCLUDED_TERMS]


def _sanitize(text: str) -> str:
    text = text.replace("-", " ")
    text = text.translate(str.maketrans("", "", string.punctuation))
    text = text.lower()
    return text


# These terms are omitted from textual representations since they are pervasive and
# unspecific to particular protein function.
_EXCLUDED_TERMS = {
    "binding domain",
    "biological_process",
    "biological process",
    "biologicalprocess",
    "c",
    "cellular_component",
    "cellular component",
    "cellularcomponent",
    "cellular_process",
    "cellularprocess",
    "cellular process",
    "like domain",
    "molecular function",
    "molecular_function",
    "molecularfunction",
    "n",
}
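An illustrative round trip through the tokenizer above: one InterPro range on a 10-residue sequence becomes per-position string tokens, which encode() packs into a [length + 2, depth] id tensor. This assumes the bundled vocabulary/LSH data files are available, and the InterPro id is a placeholder that must exist in the tokenizer's interpro2keywords mapping (otherwise tokenize raises ValueError):

from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
from esm.utils.types import FunctionAnnotation

tok = InterProQuantizedTokenizer()
annotations = [FunctionAnnotation(label="IPR000001", start=3, end=7)]  # placeholder id
tokens = tok.tokenize(annotations, seqlen=10)  # one string token per residue
ids = tok.encode(tokens)                       # <pad> added at both ends
print(ids.shape)                               # expected: (12, tok.depth)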
Dyna-1/esm/tokenization/residue_tokenizer.py
ADDED
@@ -0,0 +1,236 @@
+
from functools import cached_property
|
2 |
+
from typing import Any
|
3 |
+
|
4 |
+
import pandas as pd
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from cloudpathlib import AnyPath
|
8 |
+
|
9 |
+
from esm.tokenization.tokenizer_base import EsmTokenizerBase
|
10 |
+
from esm.utils.constants import esm3 as C
|
11 |
+
|
12 |
+
Sample = dict[str, Any]
|
13 |
+
|
14 |
+
|
15 |
+
class ResidueAnnotationsTokenizer(EsmTokenizerBase):
|
16 |
+
def __init__(self, csv_path: str | None = None, max_annotations: int = 16):
|
17 |
+
if csv_path is None:
|
18 |
+
csv_path = str(C.data_root("esm3") / C.RESID_CSV)
|
19 |
+
self.csv_path = csv_path
|
20 |
+
self.max_annotations = max_annotations
|
21 |
+
|
22 |
+
@cached_property
|
23 |
+
def _description2label(self) -> dict[str, str]:
|
24 |
+
with AnyPath(self.csv_path).open() as f: # type: ignore
|
25 |
+
df = pd.read_csv(f)
|
26 |
+
return dict(zip(df.label, df.label_clean))
|
27 |
+
|
28 |
+
@cached_property
|
29 |
+
def _labels(self) -> list[str]:
|
30 |
+
with AnyPath(self.csv_path).open() as f: # type: ignore
|
31 |
+
df = pd.read_csv(f)
|
32 |
+
labels = (
|
33 |
+
df.groupby("label_clean")["count"]
|
34 |
+
.sum()
|
35 |
+
.sort_values(ascending=False, kind="stable") # type: ignore
|
36 |
+
.index.tolist()
|
37 |
+
)
|
38 |
+
assert isinstance(labels, list)
|
39 |
+
return labels # type: ignore
|
40 |
+
|
41 |
+
def _description2id(self, description: str) -> int | None:
|
42 |
+
label = self._description2label.get(description)
|
43 |
+
return self._label2id.get(label) # type: ignore
|
44 |
+
|
45 |
+
@cached_property
|
46 |
+
def _label2id(self) -> dict[str, int]:
|
47 |
+
offset = len(self.special_tokens) + 1 # +1 for "<none>"
|
48 |
+
return {label: offset + i for i, label in enumerate(self._labels)}
|
49 |
+
|
50 |
+
@cached_property
|
51 |
+
def special_tokens(self) -> list[str]:
|
52 |
+
"""List of special tokens which come before cluster toknes in vocab."""
|
53 |
+
return ["<pad>", "<motif>", "<unk>"]
|
54 |
+
|
55 |
+
@cached_property
|
56 |
+
def vocab(self):
|
57 |
+
annotation_tokens = [f"<ra:{id}>" for _, id in self._label2id.items()]
|
58 |
+
return self.special_tokens + ["<none>"] + annotation_tokens
|
59 |
+
|
60 |
+
@cached_property
|
61 |
+
def vocab_to_index(self) -> dict[str, int]:
|
62 |
+
return {token: token_id for token_id, token in enumerate(self.vocab)}
|
63 |
+
|
64 |
+
@cached_property
|
65 |
+
def vocabulary(self) -> list[str]:
|
66 |
+
"""Full vocabulary."""
|
67 |
+
return [*self.special_tokens, "<none>", *self._labels]
|
68 |
+
|
69 |
+
def get_special_tokens_mask(self, encoded: torch.Tensor) -> torch.Tensor:
|
70 |
+
"""Determines where in the sequence are special tokens."""
|
71 |
+
return encoded[:, 0] < len(self.special_tokens)
|
72 |
+
|
73 |
+
def tokenize(
|
74 |
+
self, sample: Sample | None, sequence: str, fail_on_mismatch: bool = False
|
75 |
+
) -> list[str]:
|
76 |
+
"""
|
77 |
+
# interpro_site_starts
|
78 |
+
# interpro_site_ends # should always == interpro_site_starts. but I haven't checked overall.
|
79 |
+
# interpro_site_residues # the residue identity of the specific residue that is annotated. good for a sanity check that parsing occurred correctly.
|
80 |
+
# interpro_site_descriptions
|
81 |
+
# ASSERT (i.e. drop if bad)
|
82 |
+
# interpro_site_residues matches the residue at that position
|
83 |
+
# all these lists ^ above are the same length
|
84 |
+
"""
|
85 |
+
seqlen = len(sequence)
|
86 |
+
assert seqlen >= 0
|
87 |
+
# None means the sequence is *not annotated* - so use full <pad>
|
88 |
+
if sample is None:
|
89 |
+
return ["<pad>"] * seqlen
|
90 |
+
|
91 |
+
if any(
|
92 |
+
sample.get(field) is None
|
93 |
+
for field in [
|
94 |
+
"interpro_site_descriptions",
|
95 |
+
"interpro_site_starts",
|
96 |
+
"interpro_site_ends",
|
97 |
+
"interpro_site_residues",
|
98 |
+
]
|
99 |
+
):
|
100 |
+
return ["<pad>"] * seqlen
|
101 |
+
|
102 |
+
num_annotations = len(sample["interpro_site_descriptions"])
|
103 |
+
if any(
|
104 |
+
len(sample[field]) != num_annotations
|
105 |
+
for field in [
|
106 |
+
"interpro_site_starts",
|
107 |
+
"interpro_site_ends",
|
108 |
+
"interpro_site_residues",
|
109 |
+
]
|
110 |
+
):
|
111 |
+
# mismatched length.
|
112 |
+
return ["<pad>"] * seqlen
|
113 |
+
|
114 |
+
positional_ids = [set() for _ in range(seqlen)]
|
115 |
+
for description, start, end, residues in zip(
|
116 |
+
sample["interpro_site_descriptions"],
|
117 |
+
sample["interpro_site_starts"],
|
118 |
+
sample["interpro_site_ends"],
|
119 |
+
sample["interpro_site_residues"],
|
120 |
+
):
|
121 |
+
try:
|
122 |
+
start = int(start)
|
123 |
+
end = int(end)
|
124 |
+
except (TypeError, ValueError):
|
125 |
+
continue
|
126 |
+
|
127 |
+
# Start / End are 1-indexed [inclusive, inclusive].
|
128 |
+
if start <= 0 or end > seqlen or start > end:
|
129 |
+
print(f"invalid start/end: ({start}, {end}), len: {seqlen}")
|
130 |
+
continue
|
131 |
+
|
132 |
+
if len(residues) != (end - start) + 1:
|
133 |
+
print(f"bad reference residue: {residues}")
|
134 |
+
continue
|
135 |
+
|
136 |
+
token_id = self._description2id(description)
|
137 |
+
if token_id is None:
|
138 |
+
token_id = self.vocab_to_index["<unk>"]
|
139 |
+
|
140 |
+
for i, residue in zip(range(start - 1, end), residues):
|
141 |
+
# If there are any mismatching residues, skip the entire sample.
|
142 |
+
if sequence[i] != residue:
|
143 |
+
if fail_on_mismatch:
|
144 |
+
raise ValueError(
|
145 |
+
f"Residue mismatch at position {i} (1-indexed): {sequence[i]} != {residue}"
|
146 |
+
)
|
147 |
+
return ["<pad>"] * seqlen
|
148 |
+
|
149 |
+
positional_ids[i].add(token_id)
|
150 |
+
|
151 |
+
tokens = []
|
152 |
+
for token_ids in positional_ids:
|
153 |
+
if token_ids:
|
154 |
+
token = "<ra:" + ",".join(str(token_id) for token_id in token_ids) + ">"
|
155 |
+
else:
|
156 |
+
token = "<none>"
|
157 |
+
tokens.append(token)
|
158 |
+
return tokens
|
159 |
+
|
160 |
+
def _token2ids(self, token: str) -> list[int]:
|
161 |
+
if token.startswith("<ra:") and token.endswith(">"):
|
162 |
+
return [int(token_id) for token_id in token[4:-1].split(",")]
|
163 |
+
else:
|
164 |
+
token_id = self.vocab_to_index[token]
|
165 |
+
return [token_id]
|
166 |
+
|
167 |
+
def encode(
|
168 |
+
self, tokens: list[str], add_special_tokens: bool = True
|
169 |
+
) -> torch.Tensor:
|
170 |
+
token_ids = torch.full(
|
171 |
+
size=(len(tokens), self.max_annotations),
|
172 |
+
dtype=torch.int64,
|
173 |
+
fill_value=self.vocab_to_index["<pad>"],
|
174 |
+
)
|
175 |
+
for i, token in enumerate(tokens):
|
176 |
+
ids = self._token2ids(token)[: self.max_annotations]
|
177 |
+
token_ids[i, : len(ids)] = torch.tensor(ids)
|
178 |
+
|
179 |
+
if add_special_tokens:
|
180 |
+
token_ids = F.pad(
|
181 |
+
token_ids, (0, 0, 1, 1), value=self.vocab_to_index["<pad>"]
|
182 |
+
)
|
183 |
+
return token_ids
|
184 |
+
|
185 |
+
def decode(self, encoded: torch.Tensor) -> list[str]:
|
186 |
+
raise NotImplementedError(
|
187 |
+
"Residue annotation decoding should be handled with util.decoding.decode_residue_annotations"
|
188 |
+
)
|
189 |
+
|
190 |
+
@property
|
191 |
+
def mask_token(self) -> str:
|
192 |
+
return "<pad>"
|
193 |
+
|
194 |
+
@property
|
195 |
+
def mask_token_id(self) -> int:
|
196 |
+
return self.vocab_to_index[self.mask_token]
|
197 |
+
|
198 |
+
@property
|
199 |
+
def bos_token(self) -> str:
|
200 |
+
return "<pad>"
|
201 |
+
|
202 |
+
@property
|
203 |
+
def bos_token_id(self) -> int:
|
204 |
+
return self.vocab_to_index[self.bos_token]
|
205 |
+
|
206 |
+
@property
|
207 |
+
def eos_token(self) -> str:
|
208 |
+
return "<pad>"
|
209 |
+
|
210 |
+
@property
|
211 |
+
def eos_token_id(self) -> int:
|
212 |
+
return self.vocab_to_index[self.eos_token]
|
213 |
+
|
214 |
+
@property
|
215 |
+
def pad_token(self) -> str:
|
216 |
+
return "<pad>"
|
217 |
+
|
218 |
+
@property
|
219 |
+
def pad_token_id(self) -> int:
|
220 |
+
return self.vocab_to_index[self.pad_token]
|
221 |
+
|
222 |
+
@property
|
223 |
+
def chain_break_token(self) -> str:
|
224 |
+
return "<pad>"
|
225 |
+
|
226 |
+
@property
|
227 |
+
def chain_break_token_id(self) -> int:
|
228 |
+
return self.vocab_to_index[self.chain_break_token]
|
229 |
+
|
230 |
+
@property
|
231 |
+
def all_token_ids(self):
|
232 |
+
return list(range(len(self.vocab)))
|
233 |
+
|
234 |
+
@property
|
235 |
+
def special_token_ids(self):
|
236 |
+
return [self.vocab_to_index[token] for token in self.special_tokens]
|
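A minimal usage sketch for the residue annotation tokenizer added above. The annotation description here is hypothetical, and constructing the tokenizer with the default `csv_path` resolves the annotation CSV via `data_root("esm3")`, which may trigger a Hugging Face download:

```python
from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer

tokenizer = ResidueAnnotationsTokenizer()

sequence = "MKTAYIAKQR"
sample = {
    # One annotation covering residues 3-5 (1-indexed, inclusive). The residue
    # string must match sequence[2:5], otherwise the whole sample falls back to <pad>.
    "interpro_site_descriptions": ["Active site"],  # hypothetical description
    "interpro_site_starts": [3],
    "interpro_site_ends": [5],
    "interpro_site_residues": ["TAY"],
}

tokens = tokenizer.tokenize(sample, sequence)  # per-position tokens, e.g. '<none>' or '<ra:...>'
ids = tokenizer.encode(tokens)                 # shape: (len(sequence) + 2, max_annotations)
print(tokens[2], ids.shape)
```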
Dyna-1/esm/tokenization/sasa_tokenizer.py
ADDED
@@ -0,0 +1,153 @@
1 |
+
from functools import cached_property
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from esm.tokenization.tokenizer_base import EsmTokenizerBase
|
6 |
+
from esm.utils.constants import esm3 as C
|
7 |
+
|
8 |
+
|
9 |
+
class SASADiscretizingTokenizer(EsmTokenizerBase):
|
10 |
+
"""Tokenizer for Solvent Accessible Surface Area (SASA)."""
|
11 |
+
|
12 |
+
def __init__(self, boundaries: list[float] = C.SASA_DISCRETIZATION_BOUNDARIES):
|
13 |
+
self._boundaries = sorted(boundaries)
|
14 |
+
|
15 |
+
@cached_property
|
16 |
+
def special_tokens(self) -> list[str]:
|
17 |
+
return ["<pad>", "<motif>", "<unk>"]
|
18 |
+
|
19 |
+
@cached_property
|
20 |
+
def vocab(self) -> list[str]:
|
21 |
+
"""Discrete token vocabulary.
|
22 |
+
|
23 |
+
Returns:
|
24 |
+
token vocabulary with ranges represented as "<low-high>".
|
25 |
+
"""
|
26 |
+
boundary_strs = ["0"] + [str(b) for b in self._boundaries] + ["inf"]
|
27 |
+
range_tokens = [
|
28 |
+
f"<{low}-{high}>"
|
29 |
+
for low, high in zip(boundary_strs[:-1], boundary_strs[1:])
|
30 |
+
]
|
31 |
+
return self.special_tokens + range_tokens
|
32 |
+
|
33 |
+
@cached_property
|
34 |
+
def midpoints_tensor(self) -> torch.Tensor:
|
35 |
+
"""Midpoints of the SASA token ranges."""
|
36 |
+
boundaries = [0] + self._boundaries + [self._boundaries[-1] * 2]
|
37 |
+
midpoint_tokens = [
|
38 |
+
(float(high) + float(low)) / 2
|
39 |
+
for low, high in zip(boundaries[:-1], boundaries[1:])
|
40 |
+
]
|
41 |
+
midpoint_tokens = [float("nan"), float("nan"), float("nan")] + midpoint_tokens
|
42 |
+
return torch.Tensor(midpoint_tokens)
|
43 |
+
|
44 |
+
def midpoints(self) -> list[float]:
|
45 |
+
"""Midpoints of the SASA token ranges."""
|
46 |
+
return self.midpoints_tensor.tolist()
|
47 |
+
|
48 |
+
@cached_property
|
49 |
+
def vocab_to_index(self) -> dict[str, int]:
|
50 |
+
"""Constructs token -> token id mapping."""
|
51 |
+
return {word: i for i, word in enumerate(self.vocab)}
|
52 |
+
|
53 |
+
def get_special_tokens_mask(self, tokens: torch.Tensor) -> torch.Tensor:
|
54 |
+
"""Determines which positions are special tokens.
|
55 |
+
|
56 |
+
Args:
|
57 |
+
tokens: <int>[length]
|
58 |
+
Returns:
|
59 |
+
<bool>[length] tensor, true where special tokens are located in the input.
|
60 |
+
"""
|
61 |
+
return tokens < len(self.special_tokens)
|
62 |
+
|
63 |
+
def encode(
|
64 |
+
self, values: list[float | str], add_special_tokens: bool = True
|
65 |
+
) -> torch.Tensor:
|
66 |
+
"""Encodes SASA values as discrete tokens.
|
67 |
+
|
68 |
+
Args:
|
69 |
+
values: list of either SASA values or individual tokens. For example
|
70 |
+
[1.2, "<pad>", 10.3, <pad>, 0.]
|
71 |
+
Returns:
|
72 |
+
Token ids as tensor. Adds BOS and EOS special tokens.
|
73 |
+
"""
|
74 |
+
ids = []
|
75 |
+
if add_special_tokens:
|
76 |
+
ids.append(self.vocab_to_index["<pad>"]) # BOS
|
77 |
+
for value in values:
|
78 |
+
if isinstance(value, (float, int)):
|
79 |
+
bucket = torch.bucketize(value, torch.tensor(self._boundaries))
|
80 |
+
token_id = len(self.special_tokens) + bucket
|
81 |
+
elif isinstance(value, str):
|
82 |
+
token_id = self.vocab_to_index[value]
|
83 |
+
else:
|
84 |
+
raise TypeError(value)
|
85 |
+
ids.append(token_id)
|
86 |
+
if add_special_tokens:
|
87 |
+
ids.append(self.vocab_to_index["<pad>"]) # EOS
|
88 |
+
|
89 |
+
return torch.tensor(ids, dtype=torch.int64)
|
90 |
+
|
91 |
+
def decode_float(self, encoded: torch.Tensor) -> list[float]:
|
92 |
+
"""Decodes SASA token ids into float values."""
|
93 |
+
decoded = self.midpoints_tensor[encoded.cpu()]
|
94 |
+
nan_mask = torch.isnan(decoded)
|
95 |
+
np_arr = decoded.numpy()
|
96 |
+
np_arr[nan_mask.numpy()] = None
|
97 |
+
return np_arr.tolist()
|
98 |
+
|
99 |
+
def decode(self, encoded: torch.Tensor) -> str:
|
100 |
+
"""Decodes SASA token ids."""
|
101 |
+
return ",".join(self.vocab[i] for i in encoded)
|
102 |
+
|
103 |
+
def decode_list(self, encoded: torch.Tensor) -> list[str]:
|
104 |
+
"""Decodes SASA token ids."""
|
105 |
+
return [self.vocab[i] for i in encoded]
|
106 |
+
|
107 |
+
@property
|
108 |
+
def mask_token(self) -> str:
|
109 |
+
return "<pad>"
|
110 |
+
|
111 |
+
@property
|
112 |
+
def mask_token_id(self) -> int:
|
113 |
+
return self.vocab_to_index[self.mask_token]
|
114 |
+
|
115 |
+
@property
|
116 |
+
def bos_token(self) -> str:
|
117 |
+
return "<pad>"
|
118 |
+
|
119 |
+
@property
|
120 |
+
def bos_token_id(self) -> int:
|
121 |
+
return self.vocab_to_index[self.bos_token]
|
122 |
+
|
123 |
+
@property
|
124 |
+
def eos_token(self) -> str:
|
125 |
+
return "<pad>"
|
126 |
+
|
127 |
+
@property
|
128 |
+
def eos_token_id(self) -> int:
|
129 |
+
return self.vocab_to_index[self.eos_token]
|
130 |
+
|
131 |
+
@property
|
132 |
+
def pad_token(self) -> str:
|
133 |
+
return "<pad>"
|
134 |
+
|
135 |
+
@property
|
136 |
+
def pad_token_id(self) -> int:
|
137 |
+
return self.vocab_to_index[self.pad_token]
|
138 |
+
|
139 |
+
@property
|
140 |
+
def chain_break_token(self) -> str:
|
141 |
+
return "<pad>"
|
142 |
+
|
143 |
+
@property
|
144 |
+
def chain_break_token_id(self) -> int:
|
145 |
+
return self.vocab_to_index[self.chain_break_token]
|
146 |
+
|
147 |
+
@property
|
148 |
+
def all_token_ids(self):
|
149 |
+
return list(range(len(self.vocab)))
|
150 |
+
|
151 |
+
@property
|
152 |
+
def special_token_ids(self):
|
153 |
+
return [self.vocab_to_index[token] for token in self.special_tokens]
|
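A short sketch of how the SASA tokenizer above discretizes values against `SASA_DISCRETIZATION_BOUNDARIES`; the input values are illustrative, and string inputs pass through as literal tokens:

```python
from esm.tokenization.sasa_tokenizer import SASADiscretizingTokenizer

tok = SASADiscretizingTokenizer()

# Mixed float / token input, as in the encode() docstring; BOS and EOS are encoded as <pad>.
ids = tok.encode([1.2, "<pad>", 10.3], add_special_tokens=True)
print(tok.decode_list(ids))
# ['<pad>', '<0.8-4.0>', '<pad>', '<9.6-16.4>', '<pad>']
```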
Dyna-1/esm/tokenization/sequence_tokenizer.py
ADDED
@@ -0,0 +1,89 @@
1 |
+
from tokenizers import Tokenizer
|
2 |
+
from tokenizers.models import BPE
|
3 |
+
from tokenizers.processors import TemplateProcessing
|
4 |
+
from transformers import PreTrainedTokenizerFast
|
5 |
+
|
6 |
+
from esm.tokenization.tokenizer_base import EsmTokenizerBase
|
7 |
+
from esm.utils.constants import esm3 as C
|
8 |
+
|
9 |
+
|
10 |
+
class EsmSequenceTokenizer(PreTrainedTokenizerFast, EsmTokenizerBase):
|
11 |
+
"""
|
12 |
+
Constructs an ESM tokenizer.
|
13 |
+
"""
|
14 |
+
|
15 |
+
model_input_names = ["sequence_tokens", "attention_mask"]
|
16 |
+
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
unk_token="<unk>",
|
20 |
+
cls_token="<cls>",
|
21 |
+
pad_token="<pad>",
|
22 |
+
mask_token="<mask>",
|
23 |
+
eos_token="<eos>",
|
24 |
+
chain_break_token="|",
|
25 |
+
**kwargs,
|
26 |
+
):
|
27 |
+
all_tokens = C.SEQUENCE_VOCAB
|
28 |
+
token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)}
|
29 |
+
|
30 |
+
# a character-level tokenizer is the same as BPE with no token merges
|
31 |
+
bpe = BPE(token_to_id, merges=[], unk_token=unk_token)
|
32 |
+
tokenizer = Tokenizer(bpe)
|
33 |
+
special_tokens = [
|
34 |
+
cls_token,
|
35 |
+
pad_token,
|
36 |
+
mask_token,
|
37 |
+
eos_token,
|
38 |
+
chain_break_token,
|
39 |
+
]
|
40 |
+
self.cb_token = chain_break_token
|
41 |
+
additional_special_tokens = [chain_break_token]
|
42 |
+
|
43 |
+
tokenizer.add_special_tokens(special_tokens)
|
44 |
+
|
45 |
+
# This is where we configure the automatic addition of special tokens when we call
|
46 |
+
# tokenizer(text, add_special_tokens=True). Note that you can also configure how two
|
47 |
+
# sequences are merged if you want.
|
48 |
+
tokenizer.post_processor = TemplateProcessing( # type: ignore
|
49 |
+
single="<cls> $A <eos>",
|
50 |
+
special_tokens=[
|
51 |
+
("<cls>", tokenizer.token_to_id("<cls>")),
|
52 |
+
("<eos>", tokenizer.token_to_id("<eos>")),
|
53 |
+
],
|
54 |
+
)
|
55 |
+
super().__init__(
|
56 |
+
tokenizer_object=tokenizer,
|
57 |
+
unk_token=unk_token,
|
58 |
+
cls_token=cls_token,
|
59 |
+
pad_token=pad_token,
|
60 |
+
mask_token=mask_token,
|
61 |
+
eos_token=eos_token,
|
62 |
+
additional_special_tokens=additional_special_tokens,
|
63 |
+
**kwargs,
|
64 |
+
)
|
65 |
+
|
66 |
+
# These are a footgun, we never use the `bos` token anywhere so we're just overriding it here.
|
67 |
+
@property
|
68 |
+
def bos_token(self):
|
69 |
+
return self.cls_token
|
70 |
+
|
71 |
+
@property
|
72 |
+
def bos_token_id(self):
|
73 |
+
return self.cls_token_id
|
74 |
+
|
75 |
+
@property
|
76 |
+
def chain_break_token(self):
|
77 |
+
return self.cb_token
|
78 |
+
|
79 |
+
@property
|
80 |
+
def chain_break_token_id(self):
|
81 |
+
return self.convert_tokens_to_ids(self.chain_break_token)
|
82 |
+
|
83 |
+
@property
|
84 |
+
def all_token_ids(self):
|
85 |
+
return list(range(self.vocab_size))
|
86 |
+
|
87 |
+
@property
|
88 |
+
def special_token_ids(self):
|
89 |
+
return self.all_special_ids
|
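A minimal round trip with the character-level sequence tokenizer above (the example sequence is arbitrary):

```python
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer

tok = EsmSequenceTokenizer()

ids = tok.encode("MKTAYIAK")  # the post-processor wraps the sequence with <cls> ... <eos>
assert ids[0] == tok.cls_token_id and ids[-1] == tok.eos_token_id
print(tok.decode(ids, skip_special_tokens=True))  # tokens joined back, e.g. "M K T A Y I A K"

# Chain breaks use "|", registered as an additional special token.
assert tok.chain_break_token_id == tok.convert_tokens_to_ids("|")
```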
Dyna-1/esm/tokenization/ss_tokenizer.py
ADDED
@@ -0,0 +1,125 @@
1 |
+
from functools import cached_property
|
2 |
+
from typing import Sequence
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from esm.tokenization.tokenizer_base import EsmTokenizerBase
|
7 |
+
from esm.utils.constants import esm3 as C
|
8 |
+
|
9 |
+
|
10 |
+
class SecondaryStructureTokenizer(EsmTokenizerBase):
|
11 |
+
"""Tokenizer for secondary structure strings."""
|
12 |
+
|
13 |
+
def __init__(self, kind: str = "ss8"):
|
14 |
+
assert kind in ("ss8", "ss3")
|
15 |
+
self.kind = kind
|
16 |
+
|
17 |
+
@property
|
18 |
+
def special_tokens(self) -> list[str]:
|
19 |
+
return ["<pad>", "<motif>", "<unk>"]
|
20 |
+
|
21 |
+
@cached_property
|
22 |
+
def vocab(self):
|
23 |
+
"""Tokenzier vocabulary list."""
|
24 |
+
match self.kind:
|
25 |
+
case "ss8":
|
26 |
+
nonspecial_tokens = list(C.SSE_8CLASS_VOCAB) # "GHITEBSC"
|
27 |
+
case "ss3":
|
28 |
+
nonspecial_tokens = list(C.SSE_3CLASS_VOCAB) # HEC
|
29 |
+
case _:
|
30 |
+
raise ValueError(self.kind)
|
31 |
+
|
32 |
+
# The non-special tokens ids match amino acid tokens ids when possible.
|
33 |
+
return [*self.special_tokens, *nonspecial_tokens]
|
34 |
+
|
35 |
+
@cached_property
|
36 |
+
def vocab_to_index(self) -> dict[str, int]:
|
37 |
+
"""Constructs token -> token id mapping."""
|
38 |
+
return {word: i for i, word in enumerate(self.vocab)}
|
39 |
+
|
40 |
+
def get_special_tokens_mask(self, tokens: torch.Tensor) -> torch.Tensor:
|
41 |
+
"""Determines which positions are special tokens.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
tokens: <int>[length]
|
45 |
+
Returns:
|
46 |
+
<bool>[length] tensor, true where special tokens are located in the input.
|
47 |
+
"""
|
48 |
+
return tokens < len(self.special_tokens)
|
49 |
+
|
50 |
+
def encode(
|
51 |
+
self, sequence: str | Sequence[str], add_special_tokens: bool = True
|
52 |
+
) -> torch.Tensor:
|
53 |
+
"""Encode secondary structure string
|
54 |
+
|
55 |
+
Args:
|
56 |
+
sequence: secondary structure string, e.g. "GHHIT", or a list of tokens.
|
57 |
+
Returns:
|
58 |
+
<int>[sequence_length] token ids representing the secondary structure. Will add <cls>/<eos>.
|
59 |
+
"""
|
60 |
+
ids = []
|
61 |
+
if add_special_tokens:
|
62 |
+
ids.append(self.vocab_to_index["<pad>"]) # cls
|
63 |
+
for char in sequence:
|
64 |
+
ids.append(self.vocab_to_index[char])
|
65 |
+
if add_special_tokens:
|
66 |
+
ids.append(self.vocab_to_index["<pad>"]) # eos
|
67 |
+
return torch.tensor(ids, dtype=torch.int64)
|
68 |
+
|
69 |
+
def decode(self, encoded: torch.Tensor) -> str:
|
70 |
+
"""Decodes token ids into secondary structure string.
|
71 |
+
|
72 |
+
Args:
|
73 |
+
encoded: <int>[length] token id array.
|
74 |
+
Returns
|
75 |
+
Decoded secondary structure string.
|
76 |
+
"""
|
77 |
+
return "".join(self.vocab[i] for i in encoded)
|
78 |
+
|
79 |
+
@property
|
80 |
+
def mask_token(self) -> str:
|
81 |
+
return "<pad>"
|
82 |
+
|
83 |
+
@property
|
84 |
+
def mask_token_id(self) -> int:
|
85 |
+
return self.vocab_to_index[self.mask_token]
|
86 |
+
|
87 |
+
@property
|
88 |
+
def bos_token(self) -> str:
|
89 |
+
return "<pad>"
|
90 |
+
|
91 |
+
@property
|
92 |
+
def bos_token_id(self) -> int:
|
93 |
+
return self.vocab_to_index[self.bos_token]
|
94 |
+
|
95 |
+
@property
|
96 |
+
def eos_token(self) -> str:
|
97 |
+
return "<pad>"
|
98 |
+
|
99 |
+
@property
|
100 |
+
def eos_token_id(self) -> int:
|
101 |
+
return self.vocab_to_index[self.eos_token]
|
102 |
+
|
103 |
+
@property
|
104 |
+
def pad_token(self) -> str:
|
105 |
+
return "<pad>"
|
106 |
+
|
107 |
+
@property
|
108 |
+
def pad_token_id(self) -> int:
|
109 |
+
return self.vocab_to_index[self.pad_token]
|
110 |
+
|
111 |
+
@property
|
112 |
+
def chain_break_token(self) -> str:
|
113 |
+
return "<pad>"
|
114 |
+
|
115 |
+
@property
|
116 |
+
def chain_break_token_id(self) -> int:
|
117 |
+
return self.vocab_to_index[self.chain_break_token]
|
118 |
+
|
119 |
+
@property
|
120 |
+
def all_token_ids(self):
|
121 |
+
return list(range(len(self.vocab)))
|
122 |
+
|
123 |
+
@property
|
124 |
+
def special_token_ids(self):
|
125 |
+
return [self.vocab_to_index[token] for token in self.special_tokens]
|
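A small encode/decode round trip for the secondary structure tokenizer above; the expected ids follow from the vocab layout (three special tokens, then "GHITEBSC"):

```python
from esm.tokenization.ss_tokenizer import SecondaryStructureTokenizer

tok = SecondaryStructureTokenizer(kind="ss8")

ids = tok.encode("GHHIT")      # <pad> (id 0) doubles as BOS and EOS
print(ids.tolist())            # [0, 3, 4, 4, 5, 6, 0]
print(tok.decode(ids[1:-1]))   # "GHHIT"
```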
Dyna-1/esm/tokenization/structure_tokenizer.py
ADDED
@@ -0,0 +1,83 @@
1 |
+
from esm.tokenization.tokenizer_base import EsmTokenizerBase
|
2 |
+
from esm.utils.constants import esm3 as C
|
3 |
+
|
4 |
+
|
5 |
+
class StructureTokenizer(EsmTokenizerBase):
|
6 |
+
"""A convenince class for accessing special token ids of
|
7 |
+
the StructureTokenEncoder and StructureTokenDecoder."""
|
8 |
+
|
9 |
+
def __init__(self, codebook_size: int = C.VQVAE_CODEBOOK_SIZE):
|
10 |
+
self.vq_vae_special_tokens = {
|
11 |
+
"MASK": codebook_size,
|
12 |
+
"EOS": codebook_size + 1,
|
13 |
+
"BOS": codebook_size + 2,
|
14 |
+
"PAD": codebook_size + 3,
|
15 |
+
"CHAINBREAK": codebook_size + 4,
|
16 |
+
}
|
17 |
+
|
18 |
+
def mask_token(self) -> str:
|
19 |
+
raise NotImplementedError(
|
20 |
+
"Structure tokens are defined on 3D coordinates, not strings."
|
21 |
+
)
|
22 |
+
|
23 |
+
@property
|
24 |
+
def mask_token_id(self) -> int:
|
25 |
+
return self.vq_vae_special_tokens["MASK"]
|
26 |
+
|
27 |
+
def bos_token(self) -> str:
|
28 |
+
raise NotImplementedError(
|
29 |
+
"Structure tokens are defined on 3D coordinates, not strings."
|
30 |
+
)
|
31 |
+
|
32 |
+
@property
|
33 |
+
def bos_token_id(self) -> int:
|
34 |
+
return self.vq_vae_special_tokens["BOS"]
|
35 |
+
|
36 |
+
def eos_token(self) -> str:
|
37 |
+
raise NotImplementedError(
|
38 |
+
"Structure tokens are defined on 3D coordinates, not strings."
|
39 |
+
)
|
40 |
+
|
41 |
+
@property
|
42 |
+
def eos_token_id(self) -> int:
|
43 |
+
return self.vq_vae_special_tokens["EOS"]
|
44 |
+
|
45 |
+
def pad_token(self) -> str:
|
46 |
+
raise NotImplementedError(
|
47 |
+
"Structure tokens are defined on 3D coordinates, not strings."
|
48 |
+
)
|
49 |
+
|
50 |
+
@property
|
51 |
+
def pad_token_id(self) -> int:
|
52 |
+
return self.vq_vae_special_tokens["PAD"]
|
53 |
+
|
54 |
+
def chain_break_token(self) -> str:
|
55 |
+
raise NotImplementedError(
|
56 |
+
"Structure tokens are defined on 3D coordinates, not strings."
|
57 |
+
)
|
58 |
+
|
59 |
+
@property
|
60 |
+
def chain_break_token_id(self) -> int:
|
61 |
+
return self.vq_vae_special_tokens["CHAINBREAK"]
|
62 |
+
|
63 |
+
@property
|
64 |
+
def all_token_ids(self):
|
65 |
+
return list(range(C.VQVAE_CODEBOOK_SIZE + len(self.vq_vae_special_tokens)))
|
66 |
+
|
67 |
+
@property
|
68 |
+
def special_token_ids(self):
|
69 |
+
return self.vq_vae_special_tokens.values()
|
70 |
+
|
71 |
+
def encode(self, *args, **kwargs):
|
72 |
+
raise NotImplementedError(
|
73 |
+
"The StructureTokenizer class is provided as a convenience for "
|
74 |
+
"accessing special token ids of the StructureTokenEncoder and StructureTokenDecoder.\n"
|
75 |
+
"Please use them instead."
|
76 |
+
)
|
77 |
+
|
78 |
+
def decode(self, *args, **kwargs):
|
79 |
+
raise NotImplementedError(
|
80 |
+
"The StructureTokenizer class is provided as a convenience for "
|
81 |
+
"accessing special token ids of the StructureTokenEncoder and StructureTokenDecoder.\n"
|
82 |
+
"Please use them instead."
|
83 |
+
)
|
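Since this class only exposes token ids, a quick sketch of the resulting special-token layout for the default codebook size:

```python
from esm.tokenization.structure_tokenizer import StructureTokenizer

tok = StructureTokenizer(codebook_size=4096)
print(tok.mask_token_id)          # 4096
print(tok.eos_token_id)           # 4097
print(tok.bos_token_id)           # 4098
print(tok.pad_token_id)           # 4099
print(tok.chain_break_token_id)   # 4100
print(len(tok.all_token_ids))     # 4101 = 4096 codebook entries + 5 special tokens
```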
Dyna-1/esm/tokenization/tokenizer_base.py
ADDED
@@ -0,0 +1,44 @@
from typing import Protocol, runtime_checkable


@runtime_checkable
class EsmTokenizerBase(Protocol):
    def encode(self, *args, **kwargs): ...

    def decode(self, *args, **kwargs): ...

    @property
    def mask_token(self) -> str: ...

    @property
    def mask_token_id(self) -> int: ...

    @property
    def bos_token(self) -> str: ...

    @property
    def bos_token_id(self) -> int: ...

    @property
    def eos_token(self) -> str: ...

    @property
    def eos_token_id(self) -> int: ...

    @property
    def pad_token(self) -> str: ...

    @property
    def pad_token_id(self) -> int: ...

    @property
    def chain_break_token(self) -> str: ...

    @property
    def chain_break_token_id(self) -> int: ...

    @property
    def all_token_ids(self): ...

    @property
    def special_token_ids(self): ...
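Because the base class is a `runtime_checkable` `Protocol`, downstream code can be written against the interface rather than any concrete tokenizer. A small sketch (the `describe` helper is illustrative):

```python
from esm.tokenization.sasa_tokenizer import SASADiscretizingTokenizer
from esm.tokenization.ss_tokenizer import SecondaryStructureTokenizer
from esm.tokenization.tokenizer_base import EsmTokenizerBase


def describe(tok: EsmTokenizerBase) -> str:
    # Only relies on members declared by the Protocol.
    return f"{type(tok).__name__}: pad={tok.pad_token_id} bos={tok.bos_token_id} eos={tok.eos_token_id}"


for tok in (SecondaryStructureTokenizer(), SASADiscretizingTokenizer()):
    assert isinstance(tok, EsmTokenizerBase)  # structural check enabled by @runtime_checkable
    print(describe(tok))
```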
Dyna-1/esm/utils/constants/api.py
ADDED
@@ -0,0 +1,5 @@
MAX_TOPK_SEQUENCE = 32
MAX_TOPK_STRUCTURE = MAX_TOPK_SEQUENCE
MAX_TOPK_SECONDARY_STRUCTURE = MAX_TOPK_SEQUENCE
MAX_TOPK_SASA = MAX_TOPK_SEQUENCE
MAX_TOPK_FUNCTION = MAX_TOPK_SEQUENCE
Dyna-1/esm/utils/constants/esm3.py
ADDED
@@ -0,0 +1,130 @@
1 |
+
import os
|
2 |
+
from functools import cache
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
from huggingface_hub import snapshot_download
|
6 |
+
|
7 |
+
SEQUENCE_BOS_TOKEN = 0
|
8 |
+
SEQUENCE_PAD_TOKEN = 1
|
9 |
+
SEQUENCE_EOS_TOKEN = 2
|
10 |
+
SEQUENCE_CHAINBREAK_TOKEN = 31
|
11 |
+
SEQUENCE_MASK_TOKEN = 32
|
12 |
+
|
13 |
+
VQVAE_CODEBOOK_SIZE = 4096
|
14 |
+
VQVAE_SPECIAL_TOKENS = {
|
15 |
+
"MASK": VQVAE_CODEBOOK_SIZE,
|
16 |
+
"EOS": VQVAE_CODEBOOK_SIZE + 1,
|
17 |
+
"BOS": VQVAE_CODEBOOK_SIZE + 2,
|
18 |
+
"PAD": VQVAE_CODEBOOK_SIZE + 3,
|
19 |
+
"CHAINBREAK": VQVAE_CODEBOOK_SIZE + 4,
|
20 |
+
}
|
21 |
+
VQVAE_DIRECTION_LOSS_BINS = 16
|
22 |
+
VQVAE_PAE_BINS = 64
|
23 |
+
VQVAE_MAX_PAE_BIN = 31.0
|
24 |
+
VQVAE_PLDDT_BINS = 50
|
25 |
+
|
26 |
+
STRUCTURE_MASK_TOKEN = VQVAE_SPECIAL_TOKENS["MASK"]
|
27 |
+
STRUCTURE_BOS_TOKEN = VQVAE_SPECIAL_TOKENS["BOS"]
|
28 |
+
STRUCTURE_EOS_TOKEN = VQVAE_SPECIAL_TOKENS["EOS"]
|
29 |
+
STRUCTURE_PAD_TOKEN = VQVAE_SPECIAL_TOKENS["PAD"]
|
30 |
+
STRUCTURE_CHAINBREAK_TOKEN = VQVAE_SPECIAL_TOKENS["CHAINBREAK"]
|
31 |
+
STRUCTURE_UNDEFINED_TOKEN = 955
|
32 |
+
|
33 |
+
SASA_PAD_TOKEN = 0
|
34 |
+
|
35 |
+
SS8_PAD_TOKEN = 0
|
36 |
+
|
37 |
+
INTERPRO_PAD_TOKEN = 0
|
38 |
+
|
39 |
+
RESIDUE_PAD_TOKEN = 0
|
40 |
+
|
41 |
+
CHAIN_BREAK_STR = "|"
|
42 |
+
|
43 |
+
SEQUENCE_BOS_STR = "<cls>"
|
44 |
+
SEQUENCE_EOS_STR = "<eos>"
|
45 |
+
|
46 |
+
MASK_STR_SHORT = "_"
|
47 |
+
SEQUENCE_MASK_STR = "<mask>"
|
48 |
+
SASA_MASK_STR = "<unk>"
|
49 |
+
SS8_MASK_STR = "<unk>"
|
50 |
+
|
51 |
+
# fmt: off
|
52 |
+
SEQUENCE_VOCAB = [
|
53 |
+
"<cls>", "<pad>", "<eos>", "<unk>",
|
54 |
+
"L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K",
|
55 |
+
"Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z",
|
56 |
+
"O", ".", "-", "|",
|
57 |
+
"<mask>",
|
58 |
+
]
|
59 |
+
# fmt: on
|
60 |
+
|
61 |
+
SSE_8CLASS_VOCAB = "GHITEBSC"
|
62 |
+
SSE_3CLASS_VOCAB = "HEC"
|
63 |
+
SSE_8CLASS_TO_3CLASS_MAP = {
|
64 |
+
"G": "H",
|
65 |
+
"H": "H",
|
66 |
+
"I": "H",
|
67 |
+
"T": "C",
|
68 |
+
"E": "E",
|
69 |
+
"B": "E",
|
70 |
+
"S": "C",
|
71 |
+
"C": "C",
|
72 |
+
}
|
73 |
+
|
74 |
+
SASA_DISCRETIZATION_BOUNDARIES = [
|
75 |
+
0.8,
|
76 |
+
4.0,
|
77 |
+
9.6,
|
78 |
+
16.4,
|
79 |
+
24.5,
|
80 |
+
32.9,
|
81 |
+
42.0,
|
82 |
+
51.5,
|
83 |
+
61.2,
|
84 |
+
70.9,
|
85 |
+
81.6,
|
86 |
+
93.3,
|
87 |
+
107.2,
|
88 |
+
125.4,
|
89 |
+
151.4,
|
90 |
+
]
|
91 |
+
|
92 |
+
MAX_RESIDUE_ANNOTATIONS = 16
|
93 |
+
|
94 |
+
|
95 |
+
TFIDF_VECTOR_SIZE = 58641
|
96 |
+
|
97 |
+
|
98 |
+
@staticmethod
|
99 |
+
@cache
|
100 |
+
def data_root(model: str):
|
101 |
+
if "INFRA_PROVIDER" in os.environ:
|
102 |
+
return Path("")
|
103 |
+
# Try to download from Hugging Face if it doesn't exist
|
104 |
+
if model.startswith("esm3"):
|
105 |
+
path = Path(snapshot_download(repo_id="EvolutionaryScale/esm3-sm-open-v1"))
|
106 |
+
elif model.startswith("esmc-300"):
|
107 |
+
path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-300m-2024-12"))
|
108 |
+
elif model.startswith("esmc-600"):
|
109 |
+
path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-600m-2024-12"))
|
110 |
+
else:
|
111 |
+
raise ValueError(f"{model=} is an invalid model name.")
|
112 |
+
return path
|
113 |
+
|
114 |
+
|
115 |
+
IN_REPO_DATA_FOLDER = Path(__file__).parents[2] / "data"
|
116 |
+
|
117 |
+
INTERPRO_ENTRY = IN_REPO_DATA_FOLDER / "entry_list_safety_29026.list"
|
118 |
+
INTERPRO_HIERARCHY = IN_REPO_DATA_FOLDER / "ParentChildTreeFile.txt"
|
119 |
+
INTERPRO2GO = IN_REPO_DATA_FOLDER / "ParentChildTreeFile.txt"
|
120 |
+
INTERPRO_2ID = "data/tag_dict_4_safety_filtered.json"
|
121 |
+
|
122 |
+
LSH_TABLE_PATHS = {"8bit": "data/hyperplanes_8bit_58641.npz"}
|
123 |
+
|
124 |
+
KEYWORDS_VOCABULARY = (
|
125 |
+
IN_REPO_DATA_FOLDER / "keyword_vocabulary_safety_filtered_58641.txt"
|
126 |
+
)
|
127 |
+
KEYWORDS_IDF = IN_REPO_DATA_FOLDER / "keyword_idf_safety_filtered_58641.npy"
|
128 |
+
|
129 |
+
RESID_CSV = "data/uniref90_and_mgnify90_residue_annotations_gt_1k_proteins.csv"
|
130 |
+
INTERPRO2KEYWORDS = IN_REPO_DATA_FOLDER / "interpro_29026_to_keywords_58641.csv"
|
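A short sketch using the constants above to collapse an 8-class secondary structure string to 3 classes:

```python
from esm.utils.constants import esm3 as C

ss8 = "GHITEBSC"
ss3 = "".join(C.SSE_8CLASS_TO_3CLASS_MAP[c] for c in ss8)
print(ss3)  # "HHHCEECC"
```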
Dyna-1/esm/utils/constants/models.py
ADDED
@@ -0,0 +1,25 @@
# Model names
ESM3_OPEN_SMALL = "esm3_sm_open_v1"
ESM3_OPEN_SMALL_ALIAS_1 = "esm3-open-2024-03"
ESM3_OPEN_SMALL_ALIAS_2 = "esm3-sm-open-v1"
ESM3_OPEN_SMALL_ALIAS_3 = "esm3-open"
ESM3_STRUCTURE_ENCODER_V0 = "esm3_structure_encoder_v0"
ESM3_STRUCTURE_DECODER_V0 = "esm3_structure_decoder_v0"
ESM3_FUNCTION_DECODER_V0 = "esm3_function_decoder_v0"
ESMC_600M = "esmc_600m"
ESMC_300M = "esmc_300m"


def model_is_locally_supported(x: str):
    return x in {
        ESM3_OPEN_SMALL,
        ESM3_OPEN_SMALL_ALIAS_1,
        ESM3_OPEN_SMALL_ALIAS_2,
        ESM3_OPEN_SMALL_ALIAS_3,
    }


def normalize_model_name(x: str):
    if x in {ESM3_OPEN_SMALL_ALIAS_1, ESM3_OPEN_SMALL_ALIAS_2, ESM3_OPEN_SMALL_ALIAS_3}:
        return ESM3_OPEN_SMALL
    return x
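A quick sketch of the alias handling above:

```python
from esm.utils.constants.models import (
    ESM3_OPEN_SMALL,
    model_is_locally_supported,
    normalize_model_name,
)

assert normalize_model_name("esm3-open") == ESM3_OPEN_SMALL  # "esm3_sm_open_v1"
assert normalize_model_name("esmc_600m") == "esmc_600m"      # non-alias names pass through
assert model_is_locally_supported("esm3-sm-open-v1")
assert not model_is_locally_supported("esmc_600m")
```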
Dyna-1/esm/utils/constants/physics.py
ADDED
@@ -0,0 +1,5 @@
BB_COORDINATES = [
    [0.5256, 1.3612, 0.0000],
    [0.0000, 0.0000, 0.0000],
    [-1.5251, 0.0000, 0.0000],
]
Dyna-1/esm/utils/decoding.py
ADDED
@@ -0,0 +1,244 @@
1 |
+
import warnings
|
2 |
+
from typing import cast
|
3 |
+
|
4 |
+
import attr
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from esm.models.function_decoder import FunctionTokenDecoder
|
8 |
+
from esm.models.vqvae import StructureTokenDecoder
|
9 |
+
from esm.sdk.api import ESMProtein, ESMProteinTensor
|
10 |
+
from esm.tokenization import TokenizerCollectionProtocol
|
11 |
+
from esm.tokenization.function_tokenizer import (
|
12 |
+
InterProQuantizedTokenizer,
|
13 |
+
)
|
14 |
+
from esm.tokenization.residue_tokenizer import (
|
15 |
+
ResidueAnnotationsTokenizer,
|
16 |
+
)
|
17 |
+
from esm.tokenization.sasa_tokenizer import (
|
18 |
+
SASADiscretizingTokenizer,
|
19 |
+
)
|
20 |
+
from esm.tokenization.sequence_tokenizer import (
|
21 |
+
EsmSequenceTokenizer,
|
22 |
+
)
|
23 |
+
from esm.tokenization.ss_tokenizer import (
|
24 |
+
SecondaryStructureTokenizer,
|
25 |
+
)
|
26 |
+
from esm.tokenization.structure_tokenizer import (
|
27 |
+
StructureTokenizer,
|
28 |
+
)
|
29 |
+
from esm.tokenization.tokenizer_base import EsmTokenizerBase
|
30 |
+
from esm.utils.constants import esm3 as C
|
31 |
+
from esm.utils.function.encode_decode import (
|
32 |
+
decode_function_tokens,
|
33 |
+
decode_residue_annotation_tokens,
|
34 |
+
)
|
35 |
+
from esm.utils.misc import maybe_list
|
36 |
+
from esm.utils.structure.protein_chain import ProteinChain
|
37 |
+
from esm.utils.types import FunctionAnnotation
|
38 |
+
|
39 |
+
|
40 |
+
def decode_protein_tensor(
|
41 |
+
input: ESMProteinTensor,
|
42 |
+
tokenizers: TokenizerCollectionProtocol,
|
43 |
+
structure_token_decoder: StructureTokenDecoder,
|
44 |
+
function_token_decoder: FunctionTokenDecoder | None = None,
|
45 |
+
) -> ESMProtein:
|
46 |
+
input = attr.evolve(input) # Make a copy
|
47 |
+
|
48 |
+
sequence = None
|
49 |
+
secondary_structure = None
|
50 |
+
sasa = None
|
51 |
+
function_annotations = []
|
52 |
+
|
53 |
+
coordinates = None
|
54 |
+
|
55 |
+
# If all pad tokens, set to None
|
56 |
+
for track in attr.fields(ESMProteinTensor):
|
57 |
+
tokens: torch.Tensor | None = getattr(input, track.name)
|
58 |
+
if track.name == "coordinates" or track.name == "potential_sequence_of_concern":
|
59 |
+
continue
|
60 |
+
if tokens is not None:
|
61 |
+
tokens = tokens[1:-1] # Remove BOS and EOS tokens
|
62 |
+
tokens = tokens.flatten() # For multi-track tensors
|
63 |
+
track_tokenizer = getattr(tokenizers, track.name)
|
64 |
+
if torch.all(tokens == track_tokenizer.pad_token_id):
|
65 |
+
setattr(input, track.name, None)
|
66 |
+
# If structure track has any mask tokens, do not decode.
|
67 |
+
if track.name == "structure" and torch.any(
|
68 |
+
tokens == track_tokenizer.mask_token_id
|
69 |
+
):
|
70 |
+
setattr(input, track.name, None)
|
71 |
+
|
72 |
+
if input.sequence is not None:
|
73 |
+
sequence = decode_sequence(input.sequence, tokenizers.sequence)
|
74 |
+
|
75 |
+
plddt, ptm = None, None
|
76 |
+
if input.structure is not None:
|
77 |
+
# Note: We give priority to the structure tokens over the coordinates when decoding
|
78 |
+
coordinates, plddt, ptm = decode_structure(
|
79 |
+
structure_tokens=input.structure,
|
80 |
+
structure_decoder=structure_token_decoder,
|
81 |
+
structure_tokenizer=tokenizers.structure,
|
82 |
+
sequence=sequence,
|
83 |
+
)
|
84 |
+
elif input.coordinates is not None:
|
85 |
+
coordinates = input.coordinates[1:-1, ...]
|
86 |
+
|
87 |
+
if input.secondary_structure is not None:
|
88 |
+
secondary_structure = decode_secondary_structure(
|
89 |
+
input.secondary_structure, tokenizers.secondary_structure
|
90 |
+
)
|
91 |
+
if input.sasa is not None:
|
92 |
+
sasa = decode_sasa(input.sasa, tokenizers.sasa)
|
93 |
+
if input.function is not None:
|
94 |
+
if function_token_decoder is None:
|
95 |
+
raise ValueError(
|
96 |
+
"Cannot decode function annotations without a function token decoder"
|
97 |
+
)
|
98 |
+
function_track_annotations = decode_function_annotations(
|
99 |
+
input.function,
|
100 |
+
function_token_decoder=function_token_decoder,
|
101 |
+
function_tokenizer=tokenizers.function,
|
102 |
+
)
|
103 |
+
function_annotations.extend(function_track_annotations)
|
104 |
+
if input.residue_annotations is not None:
|
105 |
+
residue_annotations = decode_residue_annotations(
|
106 |
+
input.residue_annotations, tokenizers.residue_annotations
|
107 |
+
)
|
108 |
+
function_annotations.extend(residue_annotations)
|
109 |
+
|
110 |
+
return ESMProtein(
|
111 |
+
sequence=sequence,
|
112 |
+
secondary_structure=secondary_structure,
|
113 |
+
sasa=sasa, # type: ignore
|
114 |
+
function_annotations=function_annotations if function_annotations else None,
|
115 |
+
coordinates=coordinates,
|
116 |
+
plddt=plddt,
|
117 |
+
ptm=ptm,
|
118 |
+
potential_sequence_of_concern=input.potential_sequence_of_concern,
|
119 |
+
)
|
120 |
+
|
121 |
+
|
122 |
+
def _bos_eos_warn(msg: str, tensor: torch.Tensor, tok: EsmTokenizerBase):
|
123 |
+
if tensor[0] != tok.bos_token_id:
|
124 |
+
warnings.warn(
|
125 |
+
f"{msg} does not start with BOS token, token is ignored. BOS={tok.bos_token_id} vs {tensor}"
|
126 |
+
)
|
127 |
+
if tensor[-1] != tok.eos_token_id:
|
128 |
+
warnings.warn(
|
129 |
+
f"{msg} does not end with EOS token, token is ignored. EOS='{tok.eos_token_id}': {tensor}"
|
130 |
+
)
|
131 |
+
|
132 |
+
|
133 |
+
def decode_sequence(
|
134 |
+
sequence_tokens: torch.Tensor, sequence_tokenizer: EsmSequenceTokenizer, **kwargs
|
135 |
+
) -> str:
|
136 |
+
_bos_eos_warn("Sequence", sequence_tokens, sequence_tokenizer)
|
137 |
+
sequence = sequence_tokenizer.decode(sequence_tokens, **kwargs)
|
138 |
+
sequence = sequence.replace(" ", "")
|
139 |
+
sequence = sequence.replace(sequence_tokenizer.mask_token, C.MASK_STR_SHORT)
|
140 |
+
sequence = sequence.replace(sequence_tokenizer.cls_token, "")
|
141 |
+
sequence = sequence.replace(sequence_tokenizer.eos_token, "")
|
142 |
+
|
143 |
+
return sequence
|
144 |
+
|
145 |
+
|
146 |
+
def decode_structure(
|
147 |
+
structure_tokens: torch.Tensor,
|
148 |
+
structure_decoder: StructureTokenDecoder,
|
149 |
+
structure_tokenizer: StructureTokenizer,
|
150 |
+
sequence: str | None = None,
|
151 |
+
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
152 |
+
is_singleton = len(structure_tokens.size()) == 1
|
153 |
+
if is_singleton:
|
154 |
+
structure_tokens = structure_tokens.unsqueeze(0)
|
155 |
+
else:
|
156 |
+
raise ValueError(
|
157 |
+
f"Only one structure can be decoded at a time, got structure tokens of shape {structure_tokens.size()}"
|
158 |
+
)
|
159 |
+
_bos_eos_warn("Structure", structure_tokens[0], structure_tokenizer)
|
160 |
+
|
161 |
+
decoder_output = structure_decoder.decode(structure_tokens)
|
162 |
+
bb_coords: torch.Tensor = decoder_output["bb_pred"][
|
163 |
+
0, 1:-1, ...
|
164 |
+
] # Remove BOS and EOS tokens
|
165 |
+
bb_coords = bb_coords.detach().cpu()
|
166 |
+
|
167 |
+
if "plddt" in decoder_output:
|
168 |
+
plddt = decoder_output["plddt"][0, 1:-1]
|
169 |
+
plddt = plddt.detach().cpu()
|
170 |
+
else:
|
171 |
+
plddt = None
|
172 |
+
|
173 |
+
if "ptm" in decoder_output:
|
174 |
+
ptm = decoder_output["ptm"]
|
175 |
+
else:
|
176 |
+
ptm = None
|
177 |
+
|
178 |
+
chain = ProteinChain.from_backbone_atom_coordinates(bb_coords, sequence=sequence)
|
179 |
+
chain = chain.infer_oxygen()
|
180 |
+
return torch.tensor(chain.atom37_positions), plddt, ptm
|
181 |
+
|
182 |
+
|
183 |
+
def decode_secondary_structure(
|
184 |
+
secondary_structure_tokens: torch.Tensor, ss_tokenizer: SecondaryStructureTokenizer
|
185 |
+
) -> str:
|
186 |
+
_bos_eos_warn("Secondary structure", secondary_structure_tokens, ss_tokenizer)
|
187 |
+
secondary_structure_tokens = secondary_structure_tokens[1:-1]
|
188 |
+
secondary_structure = ss_tokenizer.decode(secondary_structure_tokens)
|
189 |
+
return secondary_structure
|
190 |
+
|
191 |
+
|
192 |
+
def decode_sasa(
|
193 |
+
sasa_tokens: torch.Tensor, sasa_tokenizer: SASADiscretizingTokenizer
|
194 |
+
) -> list[float]:
|
195 |
+
if sasa_tokens[0] != 0:
|
196 |
+
raise ValueError("SASA does not start with 0 corresponding to BOS token")
|
197 |
+
if sasa_tokens[-1] != 0:
|
198 |
+
raise ValueError("SASA does not end with 0 corresponding to EOS token")
|
199 |
+
sasa_tokens = sasa_tokens[1:-1]
|
200 |
+
if sasa_tokens.dtype in [
|
201 |
+
torch.int8,
|
202 |
+
torch.int16,
|
203 |
+
torch.int32,
|
204 |
+
torch.int64,
|
205 |
+
torch.long,
|
206 |
+
]:
|
207 |
+
# Decode if int
|
208 |
+
# handles turning NaN's into None's
|
209 |
+
sasa = sasa_tokenizer.decode_float(sasa_tokens)
|
210 |
+
else:
|
211 |
+
# If already float, just convert to list
|
212 |
+
sasa = cast(list[float], maybe_list(sasa_tokens, convert_nan_to_none=True))
|
213 |
+
|
214 |
+
return sasa
|
215 |
+
|
216 |
+
|
217 |
+
def decode_function_annotations(
|
218 |
+
function_annotation_tokens: torch.Tensor,
|
219 |
+
function_token_decoder: FunctionTokenDecoder,
|
220 |
+
function_tokenizer: InterProQuantizedTokenizer,
|
221 |
+
**kwargs,
|
222 |
+
) -> list[FunctionAnnotation]:
|
223 |
+
# No need to check for BOS/EOS as function annotations are not affected
|
224 |
+
|
225 |
+
function_annotations = decode_function_tokens(
|
226 |
+
function_annotation_tokens,
|
227 |
+
function_token_decoder=function_token_decoder,
|
228 |
+
function_tokens_tokenizer=function_tokenizer,
|
229 |
+
**kwargs,
|
230 |
+
)
|
231 |
+
return function_annotations
|
232 |
+
|
233 |
+
|
234 |
+
def decode_residue_annotations(
|
235 |
+
residue_annotation_tokens: torch.Tensor,
|
236 |
+
residue_annotation_decoder: ResidueAnnotationsTokenizer,
|
237 |
+
) -> list[FunctionAnnotation]:
|
238 |
+
# No need to check for BOS/EOS as function annotations are not affected
|
239 |
+
|
240 |
+
residue_annotations = decode_residue_annotation_tokens(
|
241 |
+
residue_annotations_token_ids=residue_annotation_tokens,
|
242 |
+
residue_annotations_tokenizer=residue_annotation_decoder,
|
243 |
+
)
|
244 |
+
return residue_annotations
|
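The first loop in `decode_protein_tensor` above clears any track that is entirely padding once BOS/EOS are stripped. A standalone sketch of that check in isolation (the pad id and tensors are illustrative):

```python
import torch

PAD_ID = 0  # illustrative pad id


def track_is_empty(tokens: torch.Tensor, pad_token_id: int = PAD_ID) -> bool:
    """True when a track carries no information beyond BOS/EOS and padding."""
    body = tokens[1:-1].flatten()  # strip BOS/EOS, flatten multi-track tensors
    return bool(torch.all(body == pad_token_id))


print(track_is_empty(torch.tensor([3, 0, 0, 0, 2])))  # True  -> track would be set to None
print(track_is_empty(torch.tensor([3, 0, 7, 0, 2])))  # False -> track gets decoded
```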
Dyna-1/esm/utils/encoding.py
ADDED
@@ -0,0 +1,246 @@
1 |
+
from typing import Sequence
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from esm.models.vqvae import StructureTokenEncoder
|
7 |
+
from esm.tokenization.function_tokenizer import (
|
8 |
+
InterProQuantizedTokenizer as EsmFunctionTokenizer,
|
9 |
+
)
|
10 |
+
|
11 |
+
from esm.tokenization.residue_tokenizer import (
|
12 |
+
ResidueAnnotationsTokenizer,
|
13 |
+
)
|
14 |
+
from esm.tokenization.sasa_tokenizer import (
|
15 |
+
SASADiscretizingTokenizer,
|
16 |
+
)
|
17 |
+
from esm.tokenization.sequence_tokenizer import (
|
18 |
+
EsmSequenceTokenizer,
|
19 |
+
)
|
20 |
+
from esm.tokenization.ss_tokenizer import (
|
21 |
+
SecondaryStructureTokenizer,
|
22 |
+
)
|
23 |
+
from esm.tokenization.structure_tokenizer import (
|
24 |
+
StructureTokenizer,
|
25 |
+
)
|
26 |
+
from esm.utils.constants import esm3 as C
|
27 |
+
from esm.utils.function.encode_decode import (
|
28 |
+
encode_function_annotations,
|
29 |
+
)
|
30 |
+
from esm.utils.structure.protein_chain import ProteinChain
|
31 |
+
from esm.utils.types import FunctionAnnotation
|
32 |
+
|
33 |
+
|
34 |
+
# Raw Defaults
|
35 |
+
def get_default_sequence(sequence_length: int) -> str:
|
36 |
+
return C.MASK_STR_SHORT * sequence_length
|
37 |
+
|
38 |
+
|
39 |
+
def get_default_secondary_structure(sequence_length: int) -> str:
|
40 |
+
return C.MASK_STR_SHORT * sequence_length
|
41 |
+
|
42 |
+
|
43 |
+
def get_default_sasa(sequence_length: int) -> Sequence[float | str | None]:
|
44 |
+
return [None] * sequence_length
|
45 |
+
|
46 |
+
|
47 |
+
# Tokenization
|
48 |
+
def tokenize_sequence(
|
49 |
+
sequence: str,
|
50 |
+
sequence_tokenizer: EsmSequenceTokenizer,
|
51 |
+
add_special_tokens: bool = True,
|
52 |
+
) -> torch.Tensor:
|
53 |
+
sequence = sequence.replace(C.MASK_STR_SHORT, sequence_tokenizer.mask_token)
|
54 |
+
sequence_tokens = sequence_tokenizer.encode(
|
55 |
+
sequence, add_special_tokens=add_special_tokens
|
56 |
+
)
|
57 |
+
sequence_tokens = torch.tensor(sequence_tokens, dtype=torch.int64)
|
58 |
+
return sequence_tokens
|
59 |
+
|
60 |
+
|
61 |
+
def tokenize_structure(
|
62 |
+
coordinates: torch.Tensor,
|
63 |
+
structure_encoder: StructureTokenEncoder,
|
64 |
+
structure_tokenizer: StructureTokenizer,
|
65 |
+
reference_sequence: str = "",
|
66 |
+
add_special_tokens: bool = True,
|
67 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
68 |
+
device = next(structure_encoder.parameters()).device
|
69 |
+
chain = ProteinChain.from_atom37(
|
70 |
+
coordinates, sequence=reference_sequence if reference_sequence else None
|
71 |
+
)
|
72 |
+
|
73 |
+
# Setup padding
|
74 |
+
if reference_sequence and len(reference_sequence) != coordinates.size(0):
|
75 |
+
raise ValueError(
|
76 |
+
f"Reference sequence length ({len(reference_sequence)}) does not match the number of residues in the coordinates ({coordinates.size(0)})"
|
77 |
+
)
|
78 |
+
|
79 |
+
left_pad = 0
|
80 |
+
right_pad = 0
|
81 |
+
|
82 |
+
if add_special_tokens:
|
83 |
+
left_pad += 1 # Add space for BOS token
|
84 |
+
right_pad += 1 # Add space for EOS token
|
85 |
+
|
86 |
+
coordinates, plddt, residue_index = chain.to_structure_encoder_inputs()
|
87 |
+
coordinates = coordinates.to(device) # (1, L, 37, 3)
|
88 |
+
plddt = plddt.to(device) # (1, L)
|
89 |
+
residue_index = residue_index.to(device) # (1, L)
|
90 |
+
_, structure_tokens = structure_encoder.encode(
|
91 |
+
coordinates, residue_index=residue_index
|
92 |
+
)
|
93 |
+
coordinates = torch.squeeze(coordinates, dim=0) # (L, 37, 3) # type: ignore
|
94 |
+
plddt = torch.squeeze(plddt, dim=0) # (L,) # type: ignore
|
95 |
+
structure_tokens = torch.squeeze(structure_tokens, dim=0) # (L,) # type: ignore
|
96 |
+
|
97 |
+
# Add space for BOS and EOS tokens
|
98 |
+
if add_special_tokens:
|
99 |
+
coordinates = F.pad(
|
100 |
+
coordinates, (0, 0, 0, 0, left_pad, right_pad), value=torch.inf
|
101 |
+
)
|
102 |
+
plddt = F.pad(plddt, (left_pad, right_pad), value=0)
|
103 |
+
structure_tokens = F.pad(
|
104 |
+
structure_tokens,
|
105 |
+
(left_pad, right_pad),
|
106 |
+
value=structure_tokenizer.mask_token_id,
|
107 |
+
)
|
108 |
+
structure_tokens[0] = structure_tokenizer.bos_token_id
|
109 |
+
structure_tokens[-1] = structure_tokenizer.eos_token_id
|
110 |
+
return coordinates, plddt, structure_tokens
|
111 |
+
|
112 |
+
|
113 |
+
def tokenize_secondary_structure(
|
114 |
+
secondary_structure: str | Sequence[str],
|
115 |
+
secondary_structure_tokenizer: SecondaryStructureTokenizer,
|
116 |
+
add_special_tokens: bool = True,
|
117 |
+
) -> torch.Tensor:
|
118 |
+
if isinstance(secondary_structure, str):
|
119 |
+
# Ensure only one char per token
|
120 |
+
secondary_structure = secondary_structure.replace(
|
121 |
+
secondary_structure_tokenizer.mask_token, C.MASK_STR_SHORT
|
122 |
+
)
|
123 |
+
|
124 |
+
# Input as list of chars
|
125 |
+
secondary_structure = [char for char in secondary_structure]
|
126 |
+
|
127 |
+
# Use tokenizer's mask token
|
128 |
+
secondary_structure = [
|
129 |
+
secondary_structure_tokenizer.mask_token if char == C.MASK_STR_SHORT else char
|
130 |
+
for char in secondary_structure
|
131 |
+
]
|
132 |
+
|
133 |
+
secondary_structure_tokens = secondary_structure_tokenizer.encode(
|
134 |
+
secondary_structure, add_special_tokens=add_special_tokens
|
135 |
+
)
|
136 |
+
return secondary_structure_tokens
|
137 |
+
|
138 |
+
|
139 |
+
def tokenize_sasa(
|
140 |
+
sasa: Sequence[float | str | None],
|
141 |
+
sasa_tokenizer: SASADiscretizingTokenizer,
|
142 |
+
add_special_tokens: bool = True,
|
143 |
+
):
|
144 |
+
sasa_tokens = sasa_tokenizer.encode(
|
145 |
+
[sasa_tokenizer.mask_token if value is None else value for value in sasa],
|
146 |
+
add_special_tokens=add_special_tokens,
|
147 |
+
)
|
148 |
+
return sasa_tokens
|
149 |
+
|
150 |
+
|
151 |
+
def tokenize_function_annotations(
|
152 |
+
function_annotations: Sequence[FunctionAnnotation],
|
153 |
+
reference_sequence: str,
|
154 |
+
function_tokenizer: EsmFunctionTokenizer,
|
155 |
+
residue_annotation_tokenizer: ResidueAnnotationsTokenizer,
|
156 |
+
add_special_tokens: bool = True,
|
157 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
158 |
+
function_tokens, residue_annotation_tokens = encode_function_annotations(
|
159 |
+
sequence=reference_sequence,
|
160 |
+
function_annotations=function_annotations,
|
161 |
+
function_tokens_tokenizer=function_tokenizer,
|
162 |
+
residue_annotations_tokenizer=residue_annotation_tokenizer,
|
163 |
+
add_special_tokens=add_special_tokens,
|
164 |
+
)
|
165 |
+
return function_tokens, residue_annotation_tokens
|
166 |
+
|
167 |
+
|
168 |
+
|
169 |
+
|
170 |
+
# Tokenized Defaults
|
171 |
+
def get_default_sequence_tokens(
|
172 |
+
sequence_length: int, sequence_tokenizer: EsmSequenceTokenizer
|
173 |
+
) -> torch.Tensor:
|
174 |
+
assert sequence_tokenizer.mask_token_id is not None
|
175 |
+
assert sequence_tokenizer.bos_token_id is not None
|
176 |
+
assert sequence_tokenizer.eos_token_id is not None
|
177 |
+
|
178 |
+
sequence_tokens = torch.full(
|
179 |
+
(sequence_length + 2,), sequence_tokenizer.mask_token_id
|
180 |
+
)
|
181 |
+
sequence_tokens[0] = sequence_tokenizer.bos_token_id
|
182 |
+
sequence_tokens[-1] = sequence_tokenizer.eos_token_id
|
183 |
+
|
184 |
+
return sequence_tokens
|
185 |
+
|
186 |
+
|
187 |
+
def get_default_structure_tokens(
|
188 |
+
sequence_length: int, structure_tokenizer: StructureTokenizer
|
189 |
+
) -> torch.Tensor:
|
190 |
+
structure_tokens = (
|
191 |
+
torch.ones((sequence_length + 2,), dtype=torch.int64)
|
192 |
+
* structure_tokenizer.mask_token_id
|
193 |
+
)
|
194 |
+
# Always include BOS and EOS tokens
|
195 |
+
structure_tokens[0] = structure_tokenizer.bos_token_id
|
196 |
+
structure_tokens[-1] = structure_tokenizer.eos_token_id
|
197 |
+
return structure_tokens
|
198 |
+
|
199 |
+
|
200 |
+
def get_default_secondary_structure_tokens(
|
201 |
+
sequence_length: int, secondary_structure_tokenizer: SecondaryStructureTokenizer
|
202 |
+
) -> torch.Tensor:
|
203 |
+
ss8_tokens = torch.full(
|
204 |
+
(sequence_length + 2,), secondary_structure_tokenizer.mask_token_id
|
205 |
+
)
|
206 |
+
ss8_tokens[0] = secondary_structure_tokenizer.bos_token_id
|
207 |
+
ss8_tokens[-1] = secondary_structure_tokenizer.eos_token_id
|
208 |
+
|
209 |
+
return ss8_tokens
|
210 |
+
|
211 |
+
|
212 |
+
def get_default_sasa_tokens(
|
213 |
+
sequence_length: int, sasa_tokenizer: SASADiscretizingTokenizer
|
214 |
+
) -> torch.Tensor:
|
215 |
+
sasa_tokens = torch.full((sequence_length + 2,), sasa_tokenizer.mask_token_id)
|
216 |
+
sasa_tokens[0] = sasa_tokenizer.bos_token_id
|
217 |
+
sasa_tokens[-1] = sasa_tokenizer.eos_token_id
|
218 |
+
return sasa_tokens
|
219 |
+
|
220 |
+
|
221 |
+
def get_default_function_tokens(
|
222 |
+
sequence_length: int, function_tokenizer: EsmFunctionTokenizer
|
223 |
+
) -> torch.Tensor:
|
224 |
+
function_tokens = (
|
225 |
+
torch.ones((sequence_length + 2, function_tokenizer.depth), dtype=torch.int64)
|
226 |
+
* function_tokenizer.pad_token_id
|
227 |
+
)
|
228 |
+
# Always include BOS and EOS tokens
|
229 |
+
function_tokens[0] = function_tokenizer.bos_token_id
|
230 |
+
function_tokens[-1] = function_tokenizer.eos_token_id
|
231 |
+
return function_tokens
|
232 |
+
|
233 |
+
|
234 |
+
def get_default_residue_annotation_tokens(
|
235 |
+
sequence_length: int, residue_annotation_tokenizer: ResidueAnnotationsTokenizer
|
236 |
+
) -> torch.Tensor:
|
237 |
+
residue_annotation_tokens = (
|
238 |
+
torch.ones((sequence_length + 2, C.MAX_RESIDUE_ANNOTATIONS), dtype=torch.int64)
|
239 |
+
* residue_annotation_tokenizer.pad_token_id
|
240 |
+
)
|
241 |
+
# Always include BOS and EOS tokens
|
242 |
+
residue_annotation_tokens[0] = residue_annotation_tokenizer.bos_token_id
|
243 |
+
residue_annotation_tokens[-1] = residue_annotation_tokenizer.eos_token_id
|
244 |
+
return residue_annotation_tokens
|
245 |
+
|
246 |
+
|
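The `get_default_*_tokens` helpers above all follow the same pattern: a mask- or pad-filled tensor of length `L + 2` with BOS and EOS written at the ends. A standalone sketch of that pattern (the ids are illustrative, chosen to match the sequence-track constants):

```python
import torch


def default_track_tokens(sequence_length: int, fill_id: int, bos_id: int, eos_id: int) -> torch.Tensor:
    # Same shape convention as the helpers above: one slot per residue plus BOS/EOS.
    tokens = torch.full((sequence_length + 2,), fill_id, dtype=torch.int64)
    tokens[0] = bos_id
    tokens[-1] = eos_id
    return tokens


print(default_track_tokens(5, fill_id=32, bos_id=0, eos_id=2))
# tensor([ 0, 32, 32, 32, 32, 32,  2])
```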
Dyna-1/esm/utils/function/encode_decode.py
ADDED
@@ -0,0 +1,187 @@
import re
from typing import Sequence

import torch

from esm.models.function_decoder import (
    FunctionTokenDecoder,
    merge_annotations,
)
from esm.tokenization.function_tokenizer import (
    InterProQuantizedTokenizer,
)
from esm.tokenization.residue_tokenizer import (
    ResidueAnnotationsTokenizer,
)
from esm.utils.constants import esm3 as C
from esm.utils.types import FunctionAnnotation


def encode_function_annotations(
    sequence: str,
    function_annotations: Sequence[FunctionAnnotation],
    function_tokens_tokenizer: InterProQuantizedTokenizer,
    residue_annotations_tokenizer: ResidueAnnotationsTokenizer,
    add_special_tokens: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    assert isinstance(
        residue_annotations_tokenizer, ResidueAnnotationsTokenizer
    ), "residue_annotations_tokenizer must be of type ResidueAnnotationsTokenizer"

    # Split the user's annotations by type
    ft_annotations: list[FunctionAnnotation] = []
    ra_annotations: list[FunctionAnnotation] = []
    for fa in function_annotations:
        assert (
            1 <= fa.start <= fa.end <= len(sequence)
        ), f"Invalid (start, end) in function annotation {fa}. Indices are 1-indexed and [inclusive, inclusive]"

        supported_label = False

        # Is it an InterPro label?
        if match := re.search(r"IPR\d+", fa.label):
            if match.group() in function_tokens_tokenizer.interpro_to_index:
                ft_annotations.append(fa)
                supported_label = True

        # Is it a function keyword?
        if fa.label in function_tokens_tokenizer._tfidf.vocab_to_index:
            ft_annotations.append(fa)
            supported_label = True

        # Is it a residue annotation?
        if fa.label in residue_annotations_tokenizer._labels:
            ra_annotations.append(fa)
            supported_label = True

        if not supported_label:
            raise ValueError(f"Unknown label in FunctionAnnotation: {fa.label}")

    # Convert function token FunctionAnnotations -> Tensor
    function_tokens = function_tokens_tokenizer.tokenize(
        annotations=ft_annotations, seqlen=len(sequence)
    )
    function_token_ids = function_tokens_tokenizer.encode(
        function_tokens, add_special_tokens=add_special_tokens
    )

    # Convert residue annotation FunctionAnnotations -> Tensor
    if ra_annotations:
        descriptions, starts, ends = zip(
            *[(anot.label, anot.start, anot.end) for anot in ra_annotations]
        )
    else:
        descriptions = starts = ends = None
    ra_tokens = residue_annotations_tokenizer.tokenize(
        {
            "interpro_site_descriptions": descriptions,
            "interpro_site_starts": starts,
            "interpro_site_ends": ends,
        },
        sequence=sequence,
        fail_on_mismatch=True,
    )
    residue_annotation_ids = residue_annotations_tokenizer.encode(
        ra_tokens, add_special_tokens=add_special_tokens
    )

    return function_token_ids, residue_annotation_ids


def decode_function_tokens(
    function_token_ids: torch.Tensor,
    function_token_decoder: FunctionTokenDecoder,
    function_tokens_tokenizer: InterProQuantizedTokenizer,
    decoder_annotation_threshold: float = 0.1,
    annotation_min_length: int | None = 5,
    annotation_gap_merge_max: int | None = 3,
) -> list[FunctionAnnotation]:
    """Decodes function token ids into function annotation predictions.

    Merges predicted InterPro annotations and function keywords into a single
    list of FunctionAnnotation predictions.

    Args:
        function_token_ids: Tensor <int>[length, depth] of function token ids.
        function_token_decoder: decoder model used to convert function tokens back
            into annotations.
        function_tokens_tokenizer: InterPro annotation tokenizer.
        decoder_annotation_threshold: predicted probability threshold for emitting
            a predicted function annotation.
        annotation_min_length: drop predicted annotations shorter than this many residues.
        annotation_gap_merge_max: merge annotations separated by gaps of at most this size.
    Returns:
        Predicted function annotations merged from both decoder outputs.
    """
    assert (
        function_token_ids.ndim == 2
    ), "function_token_ids must be of shape (length, depth)"

    annotations: list[FunctionAnnotation] = []

    # Function annotations from predicted function tokens.
    decoded = function_token_decoder.decode(
        function_token_ids,
        tokenizer=function_tokens_tokenizer,
        annotation_threshold=decoder_annotation_threshold,
        annotation_min_length=annotation_min_length,
        annotation_gap_merge_max=annotation_gap_merge_max,
    )

    # Convert predicted InterPro annotations to FunctionAnnotation.
    annotations.extend(decoded["function_keywords"])
    for annotation in decoded["interpro_annotations"]:
        annotation: FunctionAnnotation
        label = function_tokens_tokenizer.format_annotation(annotation)
        annotations.append(
            FunctionAnnotation(label=label, start=annotation.start, end=annotation.end)
        )

    return annotations


def decode_residue_annotation_tokens(
    residue_annotations_token_ids: torch.Tensor,
    residue_annotations_tokenizer: ResidueAnnotationsTokenizer,
    annotation_min_length: int | None = 5,
    annotation_gap_merge_max: int | None = 3,
) -> list[FunctionAnnotation]:
    """Decodes residue annotation tokens into FunctionAnnotations.

    Args:
        residue_annotations_token_ids: Tensor <int>[length, MAX_RESIDUE_ANNOTATIONS]
            of residue annotation tokens.
        residue_annotations_tokenizer: tokenizer of residue annotations.
        annotation_min_length: drop predicted annotations shorter than this many residues.
        annotation_gap_merge_max: merge annotations separated by gaps of at most this size.
    Returns:
        Predicted residue annotations.
    """
    assert (
        residue_annotations_token_ids.ndim == 2
    ), "residue_annotations_token_ids must be of shape (length, MAX_RESIDUE_ANNOTATIONS)"

    annotations: list[FunctionAnnotation] = []

    for depth in range(0, C.MAX_RESIDUE_ANNOTATIONS):
        token_ids = residue_annotations_token_ids[:, depth]
        nonzero_indices = torch.nonzero(token_ids).squeeze(dim=1).cpu().numpy()
        if len(nonzero_indices) == 0:
            continue
        for loc in nonzero_indices:
            vocab_index: int = token_ids[loc].item()  # type: ignore
            label = residue_annotations_tokenizer.vocabulary[vocab_index]
            if label not in [*residue_annotations_tokenizer.special_tokens, "<none>"]:
                annotation = FunctionAnnotation(label=label, start=loc, end=loc)
                annotations.append(annotation)

    annotations = merge_annotations(annotations, merge_gap_max=annotation_gap_merge_max)

    # Drop very small annotations.
    if annotation_min_length is not None:
        annotations = [
            annotation
            for annotation in annotations
            if annotation.end - annotation.start + 1 >= annotation_min_length
        ]

    return annotations
Dyna-1/esm/utils/function/interpro.py
ADDED
@@ -0,0 +1,178 @@
"""Utilities for interacting with InterPro."""

import itertools
import re
from dataclasses import dataclass
from enum import IntEnum, auto
from functools import cached_property

import networkx as nx
import pandas as pd
from cloudpathlib import AnyPath

from esm.utils.constants import esm3 as C
from esm.utils.types import PathLike


def parse_go_terms(text: str) -> list[str]:
    """Parses GO terms from a string.

    Args:
        text: String containing GO terms. Example: "GO:0008309, GO:1902267". Note that
            GO terms have exactly 7 digits.
    Returns:
        All GO terms found in the string. Example: ['GO:0008309', 'GO:1902267']
    """
    return re.findall(r"GO:(?:\d{7,})", text)


def _parse_interpro2go(path: PathLike) -> dict[str, list[str]]:
    """Parses the InterPro2GO file into a map.

    NOTE: this file has a very strange, non-standard format.

    Args:
        path: path to the InterPro2GO file from: https://www.ebi.ac.uk/GOA/InterPro2GO
    Returns:
        Mapping from InterPro id to the list of associated GO terms.
    """
    with AnyPath(path).open("r") as f:
        text = f.read()
    df = pd.Series(text.split("\n"), name="line").to_frame()
    df = df[~df.line.str.startswith("!")]
    df["interpro_id"] = df.line.apply(lambda line: re.findall(r"IPR\d+", line))
    df["go_ids"] = df.line.apply(parse_go_terms)
    df = df[df.go_ids.apply(len).gt(0) & df.interpro_id.apply(len).eq(1)]
    df["interpro_id"] = df["interpro_id"].apply(lambda xs: xs[0])  # type: ignore

    # Group all mappings together into a single map.
    df = (
        df.groupby("interpro_id")["go_ids"]  # type: ignore
        .apply(lambda group: list(itertools.chain.from_iterable(group)))
        .reset_index()
    )
    return dict(zip(df.interpro_id, df.go_ids))  # type: ignore


class InterProEntryType(IntEnum):
    """InterPro types and representation counts:

    Family                   21,942
    Domain                   14,053
    Homologous_superfamily    3,446
    Conserved_site              728
    Repeat                      374
    Active_site                 133
    Binding_site                 75
    PTM                          17
    """

    ACTIVE_SITE = 0
    BINDING_SITE = auto()
    CONSERVED_SITE = auto()
    DOMAIN = auto()
    FAMILY = auto()
    HOMOLOGOUS_SUPERFAMILY = auto()
    PTM = auto()
    REPEAT = auto()
    UNKNOWN = auto()


@dataclass
class InterProEntry:
    """Represents an InterPro entry."""

    id: str  # Example: IPR000006
    type: InterProEntryType
    name: str  # Example: "Metallothionein, vertebrate"
    description: str | None = None


class InterPro:
    """Convenience class for interacting with the InterPro ontology/data."""

    def __init__(
        self,
        entries_path: PathLike | None = None,
        hierarchy_path: PathLike | None = None,
        interpro2go_path: PathLike | None = None,
    ):
        """Constructs an interface to query InterPro entries."""

        def default(x, d):
            return x if x is not None else d

        self.entries_path = default(entries_path, C.INTERPRO_ENTRY)
        self.hierarchy_graph_path = default(hierarchy_path, C.INTERPRO_HIERARCHY)
        self.interpro2go_path = default(interpro2go_path, C.INTERPRO2GO)

    @cached_property
    def interpro2go(self) -> dict[str, list[str]]:
        """Reads the InterPro to GO term mapping."""
        assert self.interpro2go_path is not None
        return _parse_interpro2go(self.interpro2go_path)

    @cached_property
    def entries_frame(self) -> pd.DataFrame:
        """Loads the full InterPro entry set as a DataFrame.

        Columns are:
            - "id": str InterPro accession, e.g. "IPR000006".
            - "type": InterProEntryType representing the type of annotation.
            - "name": Short name of the entry.
        """
        with AnyPath(self.entries_path).open("r") as f:
            df = pd.read_csv(f, sep="\t")
        assert all(
            col in df.columns for col in ["ENTRY_AC", "ENTRY_TYPE", "ENTRY_NAME"]
        )
        df.rename(
            columns={"ENTRY_AC": "id", "ENTRY_TYPE": "type", "ENTRY_NAME": "name"},
            inplace=True,
        )
        df["type"] = df.type.str.upper().apply(
            lambda type_name: InterProEntryType[type_name]
        )
        return df

    @cached_property
    def entries(self) -> dict[str, InterProEntry]:
        """Returns all InterPro entries keyed by accession."""
        return {
            row.id: InterProEntry(  # type: ignore
                id=row.id,  # type: ignore
                type=row.type,  # type: ignore
                name=row.name,  # type: ignore
            )
            for row in self.entries_frame.itertuples()
        }

    def lookup_name(self, interpro_id: str) -> str | None:
        """Short name / title for an InterPro id."""
        if interpro_id not in self.entries:
            return None
        return self.entries[interpro_id].name

    def lookup_entry_type(self, interpro_id: str) -> InterProEntryType:
        """Looks up the entry type for an InterPro id."""
        if interpro_id in self.entries:
            return self.entries[interpro_id].type
        else:
            return InterProEntryType.UNKNOWN

    @cached_property
    def graph(self) -> nx.DiGraph:
        """Reads the InterPro hierarchy into a directed graph (edges point from child to parent)."""
        graph = nx.DiGraph()
        with AnyPath(self.hierarchy_graph_path).open("r") as f:
            parents = []
            for line in f:
                ipr = line.split("::", maxsplit=1)[0]
                ipr_strip = ipr.lstrip("-")
                level = (len(ipr) - len(ipr_strip)) // 2
                parents = parents[:level]
                graph.add_node(ipr_strip)
                if parents:
                    graph.add_edge(ipr_strip, parents[-1])
                parents.append(ipr_strip)
        return graph
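A brief usage sketch, assuming the default entry, hierarchy, and InterPro2GO assets referenced in esm.utils.constants.esm3 are available locally; the accession used is only illustrative.

from esm.utils.function.interpro import InterPro, InterProEntryType

interpro = InterPro()  # defaults to C.INTERPRO_ENTRY / C.INTERPRO_HIERARCHY / C.INTERPRO2GO
print(interpro.lookup_name("IPR000006"))        # "Metallothionein, vertebrate" if the id is in the entry table
print(interpro.lookup_entry_type("IPR999999"))  # InterProEntryType.UNKNOWN for ids not in the table
go_terms = interpro.interpro2go.get("IPR000006", [])  # GO terms mapped to this entry, possibly empty
if "IPR000006" in interpro.graph:
    parents = list(interpro.graph.successors("IPR000006"))  # hierarchy edges point from child to parent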
Dyna-1/esm/utils/function/lsh.py
ADDED
@@ -0,0 +1,102 @@
import numpy as np
from cloudpathlib import AnyPath

from esm.utils.types import PathLike


class LSHTable:
    """Single locality-sensitive-hashing table based on random hyperplane projections."""

    def __init__(self, n_bits: int, dim: int, hyperplanes: np.ndarray | None = None):
        if hyperplanes is None:
            hyperplanes = np.random.randn(n_bits, dim)
            hyperplanes = hyperplanes / np.linalg.norm(
                hyperplanes, axis=-1, keepdims=True
            )
        else:
            assert hyperplanes.shape == (n_bits, dim), (
                hyperplanes.shape,
                (n_bits, dim),
            )
        assert hyperplanes is not None
        self.hyperplanes: np.ndarray = hyperplanes
        self.values = 1 << np.arange(n_bits)

    def __call__(self, array, tokenize: bool = True):
        similarity = self.hyperplanes @ array.T
        bits = np.where(similarity >= 0, 1, 0)
        if tokenize:
            # Pack the sign bits of each input vector into a single integer token.
            tokens = bits.T @ self.values
            return tokens
        else:
            return bits.T


class LSHTokenized:
    """Multi-table LSH that maps each input vector to one integer token per table."""

    def __init__(
        self,
        n_bits: int,
        dim: int,
        num_tables: int = 1,
        filepath: PathLike | None = None,
        allow_create_hyperplanes: bool = False,  # set this to allow creating new random hyperplanes when no filepath is given
    ):
        table_hyperplanes = None
        if filepath is not None:
            filepath = AnyPath(filepath)
            if not filepath.exists():
                raise FileNotFoundError(filepath)
            table_hyperplanes = np.load(filepath)  # type: ignore
            for i in range(num_tables):
                assert str(i) in table_hyperplanes, f"Missing hyperplane for table {i}"
        elif not allow_create_hyperplanes:
            raise RuntimeError(
                "Not allowed to create hyperplanes but no filepath provided"
            )

        self.tables = [
            LSHTable(
                n_bits,
                dim,
                table_hyperplanes[str(i)] if table_hyperplanes is not None else None,
            )
            for i in range(num_tables)
        ]

    def write_hyperplanes(self, filepath: PathLike):
        hyperplanes: dict[str, np.ndarray] = {  # type: ignore
            str(i): table.hyperplanes for i, table in enumerate(self.tables)
        }
        np.savez(filepath, **hyperplanes)

    def __call__(self, array):
        tokens = np.stack([table(array) for table in self.tables], 1)
        return tokens


class LSHBitstream:
    """Single-table LSH that returns the raw sign-bit vectors instead of packed integer tokens."""

    def __init__(
        self,
        n_bits: int,
        dim: int,
        filepath: PathLike | None = None,
        allow_create_hyperplanes: bool = False,  # set this to allow creating new random hyperplanes when no filepath is given
    ):
        table_hyperplanes = None
        if filepath is not None:
            filepath = AnyPath(filepath)
            if not filepath.exists():
                raise FileNotFoundError(filepath)
            table_hyperplanes = np.load(filepath)
        elif not allow_create_hyperplanes:
            raise RuntimeError(
                "Not allowed to create hyperplanes but no filepath provided"
            )

        self.table = LSHTable(
            n_bits, dim, table_hyperplanes if table_hyperplanes is not None else None
        )

    def write_hyperplanes(self, filepath: PathLike):
        np.save(filepath, self.table.hyperplanes)

    def __call__(self, array):
        return self.table(array, tokenize=False)
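A self-contained sketch of the hashing scheme: each hyperplane contributes one sign bit per input vector, and the bits are packed into an integer token per table. The sizes used here are arbitrary.

import numpy as np

from esm.utils.function.lsh import LSHTokenized

embeddings = np.random.randn(10, 64)  # 10 example vectors of dimension 64
lsh = LSHTokenized(n_bits=8, dim=64, num_tables=2, allow_create_hyperplanes=True)
tokens = lsh(embeddings)
print(tokens.shape)  # (10, 2): one integer token in [0, 2**8) per vector per table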