gelnesr committed
Commit
74bc48e
0 Parent(s):

Initial commit

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +35 -0
  2. Dyna-1/LICENSE.txt +73 -0
  3. Dyna-1/README.md +115 -0
  4. Dyna-1/configs/af2.yml +31 -0
  5. Dyna-1/configs/baseline.yml +28 -0
  6. Dyna-1/configs/esm2.yml +28 -0
  7. Dyna-1/configs/esm3.yml +28 -0
  8. Dyna-1/data/dataloader.py +341 -0
  9. Dyna-1/data/vocab.py +104 -0
  10. Dyna-1/esm/__init__.py +2 -0
  11. Dyna-1/esm/data/ParentChildTreeFile.txt +0 -0
  12. Dyna-1/esm/data/entry_list_safety_29026.list +0 -0
  13. Dyna-1/esm/data/interpro_29026_to_keywords_58641.csv +0 -0
  14. Dyna-1/esm/data/keyword_idf_safety_filtered_58641.npy +0 -0
  15. Dyna-1/esm/data/keyword_vocabulary_safety_filtered_58641.txt +0 -0
  16. Dyna-1/esm/layers/attention.py +76 -0
  17. Dyna-1/esm/layers/blocks.py +153 -0
  18. Dyna-1/esm/layers/codebook.py +88 -0
  19. Dyna-1/esm/layers/ffn.py +29 -0
  20. Dyna-1/esm/layers/geom_attention.py +149 -0
  21. Dyna-1/esm/layers/regression_head.py +22 -0
  22. Dyna-1/esm/layers/rotary.py +221 -0
  23. Dyna-1/esm/layers/structure_proj.py +66 -0
  24. Dyna-1/esm/layers/transformer_stack.py +93 -0
  25. Dyna-1/esm/models/esm3.py +606 -0
  26. Dyna-1/esm/models/esmc.py +164 -0
  27. Dyna-1/esm/models/function_decoder.py +306 -0
  28. Dyna-1/esm/models/vqvae.py +440 -0
  29. Dyna-1/esm/pretrained.py +132 -0
  30. Dyna-1/esm/sdk/__init__.py +22 -0
  31. Dyna-1/esm/sdk/api.py +445 -0
  32. Dyna-1/esm/sdk/forge.py +580 -0
  33. Dyna-1/esm/sdk/sagemaker.py +110 -0
  34. Dyna-1/esm/tokenization/__init__.py +69 -0
  35. Dyna-1/esm/tokenization/function_tokenizer.py +429 -0
  36. Dyna-1/esm/tokenization/residue_tokenizer.py +236 -0
  37. Dyna-1/esm/tokenization/sasa_tokenizer.py +153 -0
  38. Dyna-1/esm/tokenization/sequence_tokenizer.py +89 -0
  39. Dyna-1/esm/tokenization/ss_tokenizer.py +125 -0
  40. Dyna-1/esm/tokenization/structure_tokenizer.py +83 -0
  41. Dyna-1/esm/tokenization/tokenizer_base.py +44 -0
  42. Dyna-1/esm/utils/constants/api.py +5 -0
  43. Dyna-1/esm/utils/constants/esm3.py +130 -0
  44. Dyna-1/esm/utils/constants/models.py +25 -0
  45. Dyna-1/esm/utils/constants/physics.py +5 -0
  46. Dyna-1/esm/utils/decoding.py +244 -0
  47. Dyna-1/esm/utils/encoding.py +246 -0
  48. Dyna-1/esm/utils/function/encode_decode.py +187 -0
  49. Dyna-1/esm/utils/function/interpro.py +178 -0
  50. Dyna-1/esm/utils/function/lsh.py +102 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
Dyna-1/LICENSE.txt ADDED
@@ -0,0 +1,73 @@
+ Non-Commercial License Agreement.
+
+ Copyright (c) 2025 Brandeis University, Gina El Nesr, Hannah Wayment-Steele.
+
+ IMPORTANT: PLEASE READ THIS NON-COMMERCIAL LICENSE AGREEMENT ("AGREEMENT") CAREFULLY BEFORE USING THE SOFTWARE. BY DOWNLOADING, INSTALLING, OR USING THE SOFTWARE, YOU (THE "LICENSEE") AGREE TO BE BOUND BY THE TERMS AND CONDITIONS OF THIS AGREEMENT. IF YOU DO NOT AGREE TO THE TERMS AND CONDITIONS OF THIS AGREEMENT, DO NOT DOWNLOAD, INSTALL, OR USE THE SOFTWARE.
+
+ WHEREAS, Brandeis University (the "Licensor") is the licensor of the Dyna-1 model (the "Model");
+
+ DEFINITIONS, in addition to other terms defined elsewhere, have the following meanings:
+
+ THE MODEL (the "Model" or the "Software") means the Dyna-1 released model as a combination of the Dyna-1 Model Code and the Dyna-1 Model Weights and all software, algorithms, machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code, and other elements of the foregoing distributed and made available by the Licensor on the GitHub Page (https://github.com/WaymentSteeleLab/Dyna-1) or other code expressly listed within the GitHub Page including but not limited to Google Colab Notebooks, each as may be updated and amended from time to time, whether in source or object form. Any instances of any source code or weights that depend explicitly on ESM-3 Models are licensed under the EvolutionaryScale Cambrian Non-Commercial License Agreement and are subject to the more restrictive terms.
+
+ COMMERCIAL USE means any use or activity intended for or directed towards commercial advantage or monetary compensation, including without limitation, the development of any product, method, or service intended to be sold or made available for a fee.
+
+ COMMERCIAL ENTITY means any individual, entity, or organization involved in any capacity in Commercial Use. For the purposes of this Agreement, references to a Commercial Entity expressly exclude any universities, non-profit organizations, not-for-profit entities, research institutions, educational bodies, and government bodies.
+
+ NON-COMMERCIAL USE means use not leading to or directed towards commercial advantage or monetary compensation, or the facilitation of development of any product, method, or service to be sold or made available for a fee. Non-commercial use excludes any activities conducted by commercial entities, even if those activities themselves do not directly generate revenue. Examples of non-commercial use include academic research, personal projects, and use by non-profit organizations.
+
+ DERIVATIVE WORKS means any work, in source or object form, that is based on or derived from the Model.
+
+ OUTPUT DERIVATIVES means any outputs resulting from use of the Model. This includes, but is not limited to, any predictions that result from the Model.
+
+ NOW, THEREFORE, the Licensor and the Licensee hereto agree as follows:
+
+ GRANT OF LICENSE.
+
+ Non-Commercial Use. The Model is provided under the MIT License for non-commercial use. The MIT License is included below.
+
+ Commercial Use. If the Licensee is or represents a Commercial Entity and wishes to use the Model for Commercial Use, the Licensee must first contact the Licensor for a commercial license. Examples of Commercial Use include but are not limited to:
+
+ Selling the Model, either as is or as part of a larger software package;
+
+ Using the Model as an interface or component of a service where customers pay to use it;
+
+ Providing the Model, either as is or as part of a larger software package;
+
+ Incorporating the Model, in part or in whole, into any commercial operation or workflow that may lead to commercial use of a product, method, or service;
+
+ Use of any derivative works or output derivatives towards any commercial operation or workflow that may lead to commercial use.
+
+ RESTRICTIONS. The Licensee shall not:
+
+ Use, modify, or distribute the Model for commercial purposes without receiving a commercial license.
+
+ Sublicense, assign, or transfer the Model without Licensor's prior written consent.
+
+ Use the Model in any manner that infringes upon any third-party rights.
+
+ Use the Model in any way that is not expressly permitted by this Agreement.
+
+ OWNERSHIP. Licensor retains all right, title, and interest in and to the Model, including all intellectual property rights therein. No rights are granted to Licensee other than those expressly set forth in this Agreement.
+
+ TERM AND TERMINATION. This Agreement shall commence upon Licensee's initial use of the Model and continue unless terminated by either party upon providing written notice to the other party. Upon termination, Licensee shall immediately cease using the Model and destroy all copies thereof.
+
+ LIABILITY. Licensor shall not be liable for any direct, indirect, incidental, special, consequential, or exemplary damages, including but not limited to damages for loss of profits, goodwill, use, data, or other intangible losses, resulting from the use of the Model.
+
+ GOVERNING LAW. This Agreement shall be governed by and construed in accordance with the laws of the United States of America.
+
+ MIT LICENSE
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+ By using the Software, Licensee acknowledges that they have read this Agreement, understood it, and agreed to be bound by its terms and conditions.
+
+ CONTACT INFORMATION.
+ For commercial license inquiries, please contact: Brandeis University Office of Technology Licensing at [email protected]
+ For technical inquiries, please contact: [email protected], [email protected]
+
+ [END OF AGREEMENT]
Dyna-1/README.md ADDED
@@ -0,0 +1,115 @@
+ # Dyna-1
+ [![Requires Python 3.10+](https://img.shields.io/badge/Python-3.10+-blue.svg?logo=python&logoColor=white)](https://python.org/downloads)
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/WaymentSteeleLab/Dyna-1/blob/main/colab/Dyna_1.ipynb)
+
+ ![image](assets/dyna1.png)
+
+ Dyna-1 is a model introduced in our paper, ["Learning millisecond protein dynamics from what is missing in NMR spectra"](https://www.biorxiv.org/content/10.1101/2025.03.19.642801v1).
+
+ Given a sequence and/or structure, Dyna-1 predicts the probability that each residue experiences micro-millisecond motions.
+
+ Dyna-1 was built using the `esm3-sm-open-v1` weights from ESM-3. Inference with this model is subject to the EvolutionaryScale Cambrian Non-Commercial License Agreement of the ESM-3 Model and requires read access to the weights found [here](https://huggingface.co/EvolutionaryScale/esm3-sm-open-v1). We also make available an alternate version of Dyna-1 that uses ESM-2 embeddings; use of this model is subject to a Non-Commercial License Agreement.
+
+ To make Dyna-1 readily accessible for research purposes, we also provide a [Google Colab](https://colab.research.google.com/github/WaymentSteeleLab/Dyna-1/blob/main/colab/Dyna_1.ipynb).
+
+ We provide the curated datasets used to evaluate Dyna-1: 133 R1/R2/NOE datasets ("RelaxDB") and 10 Carr-Purcell-Meiboom-Gill relaxation-dispersion datasets ("RelaxDB-CPMG").
+
+ In this repository we provide:
+ * [Installation](#installation)
+ * [Inference code for Dyna-1](#inference)
+ * [PyMol visualization](#visualization)
+ * [RelaxDB and RelaxDB-CPMG datasets](#datasets)
+ * [Training code for Dyna-1](#training)
+ * [Citation](#citation)
+ * [Acknowledgements](#acknowledgement)
+
+ If you have any questions not covered here, please create an issue and we will get back to you as soon as possible.
+
+ # Installation
+ To run the scripts in this repository, we recommend using a conda environment. First, clone this repository, then navigate to the root directory and run:
+ ```
+ conda create -n dyna1 python=3.11
+ conda activate dyna1
+ ```
+ This package requires PyTorch, ideally with GPU support. For more information, follow the instructions at https://pytorch.org/get-started/locally/ for your system and CUDA version. We used PyTorch 2.5.0 with CUDA 12.4. To install all of the requirements:
+ ```
+ pip install -r requirements.txt
+ ```
+ Then, download the model weights and place them in the `model/weights` folder. The weights can be found on 🤗HuggingFace at <a href='https://huggingface.co/gelnesr/Dyna-1'>gelnesr/Dyna-1</a>. More information on how to download them can be found <a href='https://github.com/gelnesr/Dyna-1-public/blob/main/model/weights/README.md'>here</a>.
+
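+ As a hedged sketch of one way to do this programmatically (assuming the checkpoint files sit at the top level of the `gelnesr/Dyna-1` repo), the `huggingface_hub` client can mirror the repo into that folder:
+
+ ```python
+ from huggingface_hub import snapshot_download
+
+ # download every file of the model repo into model/weights/
+ snapshot_download(repo_id='gelnesr/Dyna-1', local_dir='model/weights')
+ ```
+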
+ # Inference
+
+ The best-performing Dyna-1 is based on ESM-3. To run this version, you will have to request access to the ESM-3 `esm3-sm-open-v1` weights on HuggingFace [here](https://huggingface.co/EvolutionaryScale/esm3-sm-open-v1). Follow the steps to agree to their license terms and receive your access token to the model weights.
+
+ > [!NOTE]
+ > If this is your first time requesting access to the ESM-3 weights, you may need to set up your access token. For more information on how to set up an SSH key, please consult <a href='https://huggingface.co/docs/hub/en/security-git-ssh'>this</a> tutorial. Alternatively, you can use the HuggingFace login prompt, which will ask for the access token each time you re-instantiate Dyna-1. This can be configured by adding the following code to the inference script: `from huggingface_hub import login; login()`
+
+ To run inference using our best-performing model, run:
+
+ ```
+ python dyna1.py --pdb <PDB CODE or PATH> --chain <CHAIN> --name <NAME> --use_pdb_seq --write_to_pdb
+ ```
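+
+ For example, a hypothetical invocation (the PDB entry, chain, and job name below are illustrative placeholders, not files shipped with this repo):
+
+ ```
+ python dyna1.py --pdb 1ake --chain A --name adk_example --use_pdb_seq --write_to_pdb
+ ```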
+
+ We provide three options for running inference: sequence and structure input (best performance!), sequence only, or structure only. Additionally, we make it possible to test different sequences against the same backbone. Examples of how to run each of these modes can be found in the `scripts/` folder.
+
+ To write the probabilities to the input structure, make sure to pass the `--write_to_pdb` flag.
+
+ Alternatively, we also provide a version of Dyna-1 based on ESM-2. To run inference using this version of the model, run:
+
+ ```
+ python dyna1-esm2.py --sequence <SEQUENCE> --name <NAME>
+ ```
+
+ # Visualization
+
+ We visualize probabilities of exchange on protein structures using [PyMol](https://www.pymol.org). To re-create the putty visualization on your protein, import the pdb file into PyMol and copy-paste the following commands into the PyMol command line:
+
+ ```
+ cartoon putty; set cartoon_putty_transform, 6; set cartoon_putty_radius, 0.25; set cartoon_putty_range, 0.1; set cartoon_putty_scale_max, 10
+ ```
+
+ Annotated:
+ ```
+ cartoon putty
+ set cartoon_putty_transform, 6   # scaled linear transformation
+ set cartoon_putty_radius, 0.25   # min radius, for p=0
+ set cartoon_putty_range, 0.1     # min_radius / max_radius, sets max_radius=2.5
+ set cartoon_putty_scale_max, 10  # max_radius / min_radius
+ ```
+
+ # Datasets
+
+ *RelaxDB* contains 133 R1/R2/NOE datasets curated from the [BMRB](https://bmrb.io/) and from the literature.
+
+ *RelaxDB-CPMG* contains motion labels derived from 10 CPMG relaxation-dispersion datasets curated from the literature.
+
+ These datasets are made available on 🤗HuggingFace at <a href='https://huggingface.co/datasets/gelnesr/RelaxDB'>datasets/gelnesr/RelaxDB</a>.
+
+ In this repo, you can find:
+ - data formatted for input into Dyna-1, in `data/RelaxDB_pkls_22jan2025.zip`
+ - the datasets in JSON format, in `data/RelaxDB_datasets/`
+ - demo notebooks for visualizing the datasets and using them to evaluate model outputs, in `analysis/`
+
+ # Training
+
+ Training code will be made available upon journal publication.
+
+ # Citation
+
+ If you are using our code, datasets, or model, please use the following citation:
+ ```bibtex
+ @article{Dyna-1,
+   author = {Wayment-Steele, Hannah K. and El Nesr, Gina and Hettiarachchi, Ramith and Kariyawasam, Hasindu and Ovchinnikov, Sergey and Kern, Dorothee},
+   title = {Learning millisecond protein dynamics from what is missing in NMR spectra},
+   year = {2025},
+   doi = {10.1101/2025.03.19.642801},
+   journal = {bioRxiv}
+ }
+ ```
+ # Acknowledgements
+
+ We would like to acknowledge the EvolutionaryScale team for their contributions to the field with ESM-3. The code in `esm` is imported from `evolutionaryscale/esm` with all modifications identified, and includes the associated LICENSE terms for the ESM-3 model.
+
+ We would also like to acknowledge the FAIR team for their contributions to the field with ESM-2. The ESM-2 model is loaded via the HuggingFace API.
+
+ We thank Katie Henzler-Wildman, Magnus Wolf-Watz, Elan Eisenmesser, J. Patrick Loria, Marcellus Ubbelink, George Lisi, Sam Butcher, and Nicolas Doucet for sharing data. We thank Martin Stone for sharing the Indiana Dynamics Database data his group curated in 2000.
Dyna-1/configs/af2.yml ADDED
@@ -0,0 +1,31 @@
+ model_name: 'af2'
+ train:
+   epochs: 500
+   batchsize: 4
+   accum_steps: 4
+   lr: 0.00001
+   dropout: 0.1
+   num_workers: 1
+ model:
+   hidden_size: 128
+   res_count: 32
+   length: 400
+   nheads: 8
+   nlayers: 12
+ dir:
+   save_dir: '/path/to/save/dir'
+ data:
+   pair_rep: None
+   sample_clusters: True
+ relaxdb:
+   split: 'relaxdb'
+   crop_len: 400
+   type: 'rex'
+ cpmg:
+   split: 'relaxdb-cpmg'
+   crop_len: 400
+   type: 'cpmg'
+ wandb:
+   project: "New-Project"
+   team: 'my-project-team'
+   dir: '/wandb/log/dir'
Dyna-1/configs/baseline.yml ADDED
@@ -0,0 +1,28 @@
+ model_name: 'baseline'
+ train:
+   epochs: 500
+   batchsize: 16
+   accum_steps: 1
+   lr: 0.000001
+   dropout: 0.1
+   num_workers: 2
+ model:
+   nheads: 6
+   nlayers: 10
+ dir:
+   save_dir: '/path/to/save/dir'
+ data:
+   pair_rep: None
+   sample_clusters: True
+ relaxdb:
+   split: 'relaxdb'
+   crop_len: 367
+   type: 'rex'
+ cpmg:
+   split: 'relaxdb-cpmg'
+   crop_len: 367
+   type: 'cpmg'
+ wandb:
+   project: "New-Project"
+   team: 'my-project-team'
+   dir: '/wandb/log/dir'
Dyna-1/configs/esm2.yml ADDED
@@ -0,0 +1,28 @@
+ model_name: 'esm2_t30_150M_UR50D'
+ train:
+   epochs: 500
+   batchsize: 16
+   accum_steps: 1
+   lr: 0.000001
+   dropout: 0.1
+   num_workers: 2
+ model:
+   nheads: 8
+   nlayers: 12
+ dir:
+   save_dir: '/path/to/save/dir'
+ data:
+   pair_rep: None
+   sample_clusters: True
+ relaxdb:
+   split: 'relaxdb'
+   crop_len: 367
+   type: 'rex'
+ cpmg:
+   split: 'relaxdb-cpmg'
+   crop_len: 367
+   type: 'cpmg'
+ wandb:
+   project: "New-Project"
+   team: 'my-project-team'
+   dir: '/wandb/log/dir'
Dyna-1/configs/esm3.yml ADDED
@@ -0,0 +1,28 @@
+ model_name: 'esm3'
+ train:
+   epochs: 500
+   batchsize: 16
+   accum_steps: 1
+   lr: 0.000001
+   dropout: 0.1
+   num_workers: 2
+ model:
+   nheads: 6
+   nlayers: 12
+ dir:
+   save_dir: '/path/to/save/dir'
+ data:
+   pair_rep: None
+   sample_clusters: True
+ relaxdb:
+   split: 'relaxdb'
+   crop_len: 367
+   type: 'rex'
+ cpmg:
+   split: 'relaxdb-cpmg'
+   crop_len: 367
+   type: 'cpmg'
+ wandb:
+   project: "New-Project"
+   team: 'my-project-team'
+   dir: '/wandb/log/dir'
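As a brief aside on how configs like the ones above are typically consumed (a minimal sketch assuming PyYAML and the key nesting shown; note that a bare `None` value parses as the string `'None'` in YAML, not as Python `None`):

```python
import yaml

# load the ESM-3 training config shown above
with open('Dyna-1/configs/esm3.yml') as f:
    cfg = yaml.safe_load(f)

print(cfg['train']['lr'])       # 1e-06
print(cfg['data']['pair_rep'])  # 'None' -- a string, not Python's None
```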
Dyna-1/data/dataloader.py ADDED
@@ -0,0 +1,341 @@
+ import numpy as np
+ import warnings
+ import random
+ import torch
+ import os
+ import pickle
+ import data.vocab as vocab
+ import pandas as pd
+ from typing import Tuple, List, Any
+ from esm.sdk.api import ESMProtein
+ from transformers import AutoTokenizer
+ import utils
+
+ class DynaData(torch.utils.data.Dataset):
+     """
+     For each protein, we use a pkl file that contains:
+         seq: the domain sequence, stored as an L-length string
+         assns: string containing labels of dynamics type
+     """
+
+     def __init__(
+         self,
+         split,
+         type = 'missing',
+         sample_clusters = False,
+         cluster_file = None,
+         crop_len = 300,
+         missing_only = False,
+         rex_only = False,
+         unsuppressed = False,
+         method = None,
+         pair_rep = None,
+         return_dssp = False
+     ):
+         super().__init__()
+
+         self.return_dssp = return_dssp
+         self.crop_len = crop_len
+         self.sample_clusters = sample_clusters
+         self.label_tokenizer = vocab.label_tokenizer(type = type,
+                                                      missing_only = missing_only,
+                                                      rex_only = rex_only,
+                                                      unsuppressed = unsuppressed)
+
+         # tokenization is the same for all ESM models, use the lightest one
+         self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
+         self.proline = self.tokenizer.get_vocab()['P']
+
+         self.method = method[0]
+         self.model = method[1]
+
+         if isinstance(split, str):
+             split = [split]
+
+         self.all_names, self.names = [], []
+
+         # read in all pdb names
+         for fil in split:
+             filename = f'data/split_files/{fil}.txt'
+             with open(filename, 'r') as f:
+                 self.all_names.extend(f.read().splitlines())
+
+         # set up cluster sampling
+         if self.sample_clusters:
+             self.cluster_info = pd.read_csv(f'data/{cluster_file}.tsv', sep='\t')
+             self.cluster_info['cluster'] = self.cluster_info.apply(lambda row: row['cluster'], axis=1)
+             for nm in self.all_names:
+                 subset = self.cluster_info.loc[self.cluster_info.entry_ID == nm]
+                 if len(subset) == 0:
+                     print('NO!', nm)
+                 cluster_ind = subset['cluster'].iloc[0]
+                 if cluster_ind not in self.names:
+                     self.names.append(cluster_ind)
+         else:
+             self.names = self.all_names
+
+         self.pair_rep_dir = pair_rep
+
+     def __len__(self):
+         return len(self.names)
+
+     def __baseline_get_item__(self, name, obj, crop_start):
+         if crop_start > -1:
+             sequence = obj['sequence'][crop_start:crop_start+self.crop_len]
+         else:
+             sequence = obj['sequence'][:self.crop_len]
+
+         sequence_tokens = self.tokenizer.encode(sequence,
+                                                 add_special_tokens=False,
+                                                 padding='max_length',
+                                                 max_length=self.crop_len,
+                                                 return_tensors='np').squeeze()
+
+         # set mask to 1 for length of seq, padded tokens then are 0
+         eval_mask = np.zeros_like(sequence_tokens)
+         eval_mask[:len(sequence)] = 1
+
+         sequence_id = sequence_tokens != 0
+
+         # mask prolines in eval
+         eval_mask[sequence_tokens == self.proline] = 0
+
+         return sequence_tokens, sequence_id, eval_mask
+
+     def __af2_get_item__(self, name, obj, crop_start):
+         """
+         Prepares input for the AF2-pair model
+         """
+         pair_rep = np.load(f"{self.pair_rep_dir}/{name}.npy")
+         labels, seq = obj['label'], obj['sequence']
+
+         if crop_start > -1:
+             pair_rep = pair_rep[crop_start:crop_start+self.crop_len,
+                                 crop_start:crop_start+self.crop_len, :]
+             labels = labels[crop_start:crop_start+self.crop_len]
+             seq = seq[crop_start:crop_start+self.crop_len]
+
+         eval_mask = np.zeros((pair_rep.shape[0],))
+
+         prolines = [i for i, aa in enumerate(seq) if aa == 'P']
+         eval_mask[:len(labels)] = 1
+
+         sequence_id = eval_mask != 0
+         eval_mask[prolines] = 0
+         x = pair_rep.shape[0]
+
+         eval_mask = np.pad(eval_mask, (0, self.crop_len - len(eval_mask)), mode='constant')
+         sequence_id = np.pad(sequence_id, (0, self.crop_len - len(sequence_id)), mode='constant')
+         if x < self.crop_len:
+             pair_rep = np.pad(pair_rep, ((0, self.crop_len - x), (0, self.crop_len - x), (0, 0)), mode='constant', constant_values=0)
+
+         return pair_rep, sequence_id, eval_mask
+
+     def __esm3_get_item__(self, name, crop_start, data_path = 'esm3_data/'):
+         """
+         Prepares input for the ESM3 model
+         """
+         pkl_fname = os.path.join(data_path, f"{name}.pkl")
+
+         try:
+             with open(pkl_fname, "rb") as f:
+                 esm_data = pickle.load(f)
+         except Exception:
+             print(f'writing pkl for {name} {crop_start}')
+             pdb_path = f'pdbs/{name}.pdb'
+             protein = ESMProtein.from_pdb(pdb_path)
+
+             self.model.eval()
+             encoder = self.model.model.encode(protein)
+             self.model.train()
+
+             seq = encoder.sequence.cpu().detach()[1:-1][:700]
+             struct = encoder.structure.cpu().detach()[1:-1][:700]
+
+             sequence_tokens = np.full(700, 1, dtype=np.int32)      ## sequence pad token is 1
+             structure_tokens = np.full(700, 4099, dtype=np.int32)  ## structure pad token is 4099
+
+             sequence_tokens[:len(seq)] = seq
+             structure_tokens[:len(struct)] = struct
+
+             sequence_id = sequence_tokens != 1
+
+             obj = {'name': name, 'len': len(seq), 'seq_tokens': sequence_tokens,
+                    'struct_tokens': structure_tokens, 'sequence_id': sequence_id}
+
+             with open(pkl_fname, 'wb') as f:
+                 pickle.dump(obj, f)
+
+             with open(pkl_fname, "rb") as f:
+                 esm_data = pickle.load(f)
+
+         if crop_start > -1:
+             sequence_tokens = esm_data['seq_tokens'][crop_start:crop_start+self.crop_len]
+             structure_tokens = esm_data['struct_tokens'][crop_start:crop_start+self.crop_len]
+             sequence_id = esm_data['sequence_id'][crop_start:crop_start+self.crop_len]
+         else:
+             sequence_tokens = esm_data['seq_tokens'][:self.crop_len]
+             structure_tokens = esm_data['struct_tokens'][:self.crop_len]
+             sequence_id = esm_data['sequence_id'][:self.crop_len]
+
+         eval_mask = np.zeros_like(sequence_tokens)
+         eval_mask[:esm_data['len']] = 1
+         eval_mask[sequence_tokens == self.proline] = 0
+
+         return sequence_tokens, structure_tokens, sequence_id, eval_mask
+
+     def __esm2_get_item__(self, obj, crop_start):
+         """
+         Prepares input for the ESM2 model
+         """
+         sequence = obj['sequence'].replace(' ', '')
+         if crop_start > -1:
+             sequence = sequence[crop_start:crop_start+self.crop_len]
+
+         sequence_tokens = self.tokenizer.encode(sequence,
+                                                 add_special_tokens=False,
+                                                 padding='max_length',
+                                                 max_length=self.crop_len,
+                                                 return_tensors='np').squeeze()
+
+         # set mask to 1 for length of sequence, padded tokens then are 0
+         eval_mask = np.zeros_like(sequence_tokens)
+         eval_mask[:len(sequence)] = 1
+         sequence_id = eval_mask != 0
+         eval_mask[sequence_tokens == self.proline] = 0
+
+         return sequence_tokens, sequence_id, eval_mask
+
+     def __get_dssp__(self, name, crop_start):
+         """
+         Prepares DSSP information for a given sequence
+         """
+         try:
+             dssp_csv = pd.read_csv('data/dssp.csv')
+             entry = dssp_csv.loc[dssp_csv.PDB == str(name)].iloc[0]
+         except Exception:
+             entry = {}
+             entry['DSSP'] = utils.calc_dssp(f'pdbs/{name}.pdb')
+
+         assert len(entry) > 0
+         if crop_start == -1:
+             dssp_data = entry['DSSP'].replace(' ', '')[:self.crop_len]
+         else:
+             dssp_data = entry['DSSP'].replace(' ', '')[crop_start:crop_start + self.crop_len]
+
+         dssp = np.zeros(self.crop_len)
+         inds = [i for i, char in enumerate(dssp_data) if char == 'C']
+         dssp[inds] = 1.0
+
+         inds = [i for i, char in enumerate(dssp_data) if char == 'H']
+         dssp[inds] = 2.0
+
+         return dssp
+
+     def __getitem__(self, idx):
+         """
+         Returns a dict with the appropriate entries for each model
+         """
+         exists = -1
+         while exists == -1:
+             name = self.names[idx]
+             if self.sample_clusters:
+                 roptions = list(self.cluster_info.loc[self.cluster_info.cluster == name]['entry_ID'].values)
+                 options = [opt for opt in roptions if opt in self.all_names]
+                 name = random.choice(options)
+             pkl_fname = f"data/mBMRB_data/{name}.pkl"
+
+             try:
+                 with open(pkl_fname, "rb") as f:
+                     obj = pickle.load(f)
+                 exists = 1
+             except Exception:
+                 print(f'{pkl_fname} not found')
+
+         assns = obj['label']
+         assns = vocab.mask_termini(assns)
+
+         crop_start = -1
+         if len(assns) > self.crop_len:
+             crop_start = np.random.choice(range(0, len(assns)-self.crop_len))
+             assns = assns[crop_start:crop_start + self.crop_len]
+
+         labels = self.label_tokenizer.convert_tokens_to_ids(assns, pad_to_length=self.crop_len)
+         labels = np.asarray(labels, np.int64)
+
+         dssp = None
+         if self.return_dssp:
+             dssp = self.__get_dssp__(name, crop_start)
+
+         if 'esm3' in self.method:
+             sequence, structure, sequence_id, eval_mask = self.__esm3_get_item__(name, crop_start)
+         elif 'esm2' in self.method:
+             sequence, sequence_id, eval_mask = self.__esm2_get_item__(obj, crop_start)
+         elif 'af2' in self.method:
+             pair_rep, sequence_id, eval_mask = self.__af2_get_item__(name, obj, crop_start)
+         elif 'baseline' in self.method:
+             sequence, sequence_id, eval_mask = self.__baseline_get_item__(name, obj, crop_start)
+
+         # mask termini for eval; a -1 label corresponds to indices that are masked in the vocabs
+         eval_mask[labels == -1] = 0
+
+         if 'esm2' in self.method:
+             return sequence, sequence_id, eval_mask, labels, name, dssp
+         elif 'esm3' in self.method:
+             return sequence, structure, sequence_id, eval_mask, labels, name, dssp
+         elif 'af2' in self.method:
+             return pair_rep, labels, sequence_id, eval_mask, name, dssp
+         elif 'baseline' in self.method:
+             return sequence, labels, sequence_id, eval_mask, name, dssp
+
+     def __collate_fn__(self, batch: List[Tuple[Any, ...]]):
+
+         if 'baseline' in self.method:
+             seqs, labels, sequence_id, eval_mask, names, dssp = tuple(zip(*batch))
+             seqs = torch.tensor(np.array(seqs))
+
+             labels = torch.from_numpy(np.array(labels)).float()
+             eval_mask = torch.from_numpy(np.array(eval_mask))
+             if self.return_dssp:
+                 dssp = torch.from_numpy(np.array(dssp))
+             sequence_id = torch.from_numpy(np.array(sequence_id))
+
+             output = {'names': names, 'seqs': seqs, 'seq_id': sequence_id,
+                       'targets': labels, 'eval_mask': eval_mask, 'dssp': dssp}
+             return output
+
+         elif 'af2' in self.method:
+             pair_reps, labels, sequence_id, eval_mask, names, dssp = tuple(zip(*batch))
+
+             pair_reps = torch.from_numpy(np.array(pair_reps, dtype=np.float64))
+             labels = torch.from_numpy(np.array(labels)).float()
+             eval_mask = torch.from_numpy(np.array(eval_mask))
+             sequence_id = torch.from_numpy(np.array(sequence_id, dtype=bool))
+
+             if self.return_dssp:
+                 dssp = torch.from_numpy(np.array(dssp))
+             output = {'names': names, 'pair_reps': pair_reps, 'targets': labels, 'seq_id': sequence_id,
+                       'eval_mask': eval_mask, 'dssp': dssp}
+             return output
+
+         elif 'esm2' in self.method:
+             seqs, sequence_id, eval_mask, label, names, dssp = tuple(zip(*batch))
+             seqs = torch.from_numpy(np.array(seqs))
+             structs = None
+             sequence_id = torch.from_numpy(np.array(sequence_id))
+
+         elif 'esm3' in self.method:
+             seqs, structs, sequence_id, eval_mask, label, names, dssp = tuple(zip(*batch))
+             seqs = torch.from_numpy(np.array(seqs))
+             structs = torch.from_numpy(np.array(structs))
+             sequence_id = torch.from_numpy(np.array(sequence_id))
+
+         eval_mask = torch.from_numpy(np.array(eval_mask))
+         label = torch.from_numpy(np.array(label)).float()
+         if self.return_dssp:
+             dssp = torch.from_numpy(np.array(dssp))
+
+         output = {'seqs': seqs, 'structs': structs, 'seq_id': sequence_id, 'eval_mask': eval_mask,
+                   'targets': label, 'names': names, 'dssp': dssp}
+
+         return output
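A minimal usage sketch for the dataloader above (hedged: it assumes the `relaxdb` split file and the `data/mBMRB_data/*.pkl` files it reads are present on disk; the arguments mirror the esm2 config):

```python
from torch.utils.data import DataLoader
from data.dataloader import DynaData

# method is a (name, model) pair; the model handle is only used on the
# esm3 path to tokenize structures, so None suffices for the ESM-2 variant
ds = DynaData(split='relaxdb', type='rex', crop_len=367, method=('esm2', None))
loader = DataLoader(ds, batch_size=4, collate_fn=ds.__collate_fn__)

batch = next(iter(loader))
print(batch['seqs'].shape, batch['targets'].shape)  # both (4, 367)
```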
Dyna-1/data/vocab.py ADDED
@@ -0,0 +1,104 @@
+ from collections import OrderedDict
+ from typing import List
+
+ VOCAB_MISSING = OrderedDict([
+     ('t', -1),  # no data b/c disordered terminus
+     ('x', -1),  # no data, R1/R2/NOE not reported
+     ('p', 0),   # proline, not evaluated
+     ('A', 0),   # nothing
+     ('v', 0),   # fast motion
+     ('.', 1),   # missing
+     ('b', 0),   # both fast and slow
+     ('^', 0)    # rex
+ ])
+
+ VOCAB_REX = OrderedDict([
+     ('t', -1),  # no data b/c disordered terminus
+     ('x', -1),  # no data, R1/R2/NOE not reported
+     ('p', 0),   # proline, not evaluated
+     ('A', 0),   # nothing
+     ('v', 0),   # fast motion
+     ('.', 1),   # missing
+     ('b', 1),   # both fast and slow
+     ('^', 1)    # rex
+ ])
+
+ VOCAB_CPMG = OrderedDict([
+     ('t', -1),  # no data b/c disordered terminus
+     ('P', 0),   # proline, not evaluated
+     ('N', -1),  # no data, assigned but CPMG not reported
+     ('A', 0),   # nothing
+     ('.', 1),   # missing
+     ('X', 1),   # exchange from Rex definition
+     ('Y', 0)    # exchange from unsuppressed R2
+ ])
+
+ def mask_termini(seq):
+     """
+     Mask the termini of a sequence
+     """
+     seq = seq.lstrip('.p').rjust(len(seq), 't')
+     seq = seq.rstrip('.p').ljust(len(seq), 't')
+     return seq
+
+ class label_tokenizer:
+     def __init__(self,
+                  type = 'missing',
+                  missing_only = False,
+                  rex_only = False,
+                  unsuppressed = False):
+         """
+         Tokenize the data labeling for BMRB, REX, and CPMG
+
+         Args:
+             type: (str) which type of experiment
+             missing_only: (bool) only return residues with missing peaks
+             rex_only: (bool) only return residues with Rex
+             unsuppressed: (bool) return residues with unsuppressed Rex
+         """
+         if type == 'missing':
+             self.vocab = VOCAB_MISSING.copy()
+         elif type == 'rex':
+             self.vocab = VOCAB_REX.copy()
+             if missing_only:
+                 self.vocab['b'] = 0
+                 self.vocab['^'] = 0
+         elif type == 'cpmg':
+             self.vocab = VOCAB_CPMG.copy()
+             if missing_only:
+                 self.vocab['X'] = 0
+             if unsuppressed:
+                 self.vocab['Y'] = 1
+             if rex_only:
+                 self.vocab['.'] = -1
+         self.tokens = list(self.vocab.keys())
+
+     @property
+     def vocab_size(self) -> int:
+         return len(self.vocab)
+
+     def convert_token_to_id(self, token: str) -> int:
+         """Converts a token (str/unicode) into an id using the vocab."""
+         try:
+             return self.vocab[token]
+         except KeyError:
+             raise KeyError(f"Unrecognized token: '{token}'")
+
+     def convert_tokens_to_ids(self, tokens: List[str], pad_to_length = None) -> List[int]:
+         """Converts a list of tokens (str/unicode) into a list of ids (int) using the vocab."""
+         if pad_to_length is None:
+             return [self.convert_token_to_id(token) for token in tokens]
+         else:
+             return [self.convert_token_to_id(token) for token in tokens] + [0] * (pad_to_length - len(tokens))
+
+     def convert_id_to_token(self, index: int) -> str:
+         """Converts an index (integer) into a token (string/unicode) using the vocab."""
+         try:
+             return self.tokens[index]
+         except IndexError:
+             raise IndexError(f"Unrecognized index: '{index}'")
+
+     def convert_ids_to_tokens(self, indices: List[int]) -> List[str]:
+         """Converts a list of indices (integer) into a list of tokens (string/unicode) using the vocab."""
+         return [self.convert_id_to_token(id_) for id_ in indices]
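A short sketch of how these pieces fit together (the label string below is invented for illustration):

```python
from data.vocab import label_tokenizer, mask_termini

tok = label_tokenizer(type='rex')
assns = mask_termini('..Av^p.xA.')  # leading/trailing '.'/'p' runs become 't'
ids = tok.convert_tokens_to_ids(assns, pad_to_length=12)
# under the REX vocab: 't'/'x' -> -1 (masked), '.'/'^' -> 1, the rest -> 0
```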
Dyna-1/esm/__init__.py ADDED
@@ -0,0 +1,2 @@
+ __version__ = "3.1.1"
+
Dyna-1/esm/data/ParentChildTreeFile.txt ADDED
The diff for this file is too large to render.
 
Dyna-1/esm/data/entry_list_safety_29026.list ADDED
The diff for this file is too large to render.
 
Dyna-1/esm/data/interpro_29026_to_keywords_58641.csv ADDED
The diff for this file is too large to render.
 
Dyna-1/esm/data/keyword_idf_safety_filtered_58641.npy ADDED
Binary file (469 kB).
 
Dyna-1/esm/data/keyword_vocabulary_safety_filtered_58641.txt ADDED
The diff for this file is too large to render.
 
Dyna-1/esm/layers/attention.py ADDED
@@ -0,0 +1,76 @@
+ import functools
+
+ import einops
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from esm.layers.rotary import RotaryEmbedding
+
+
+ class MultiHeadAttention(nn.Module):
+     def __init__(
+         self, d_model: int, n_heads: int, bias: bool = False, qk_layernorm: bool = True
+     ):
+         super().__init__()
+
+         self.d_model = d_model
+         self.n_heads = n_heads
+
+         self.d_head = self.d_model // self.n_heads
+         self.layernorm_qkv = nn.Sequential(
+             nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=bias)
+         )
+         self.out_proj = nn.Linear(d_model, d_model, bias=bias)
+
+         if qk_layernorm:
+             self.q_ln = nn.LayerNorm(d_model, bias=bias)
+             self.k_ln = nn.LayerNorm(d_model, bias=bias)
+         else:
+             self.q_ln = nn.Identity()
+             self.k_ln = nn.Identity()
+
+         self.rotary = RotaryEmbedding(d_model // n_heads)
+
+     def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor):
+         q = q.unflatten(-1, (self.n_heads, self.d_head))
+         k = k.unflatten(-1, (self.n_heads, self.d_head))
+         q, k = self.rotary(q, k)
+         q = q.flatten(-2, -1)
+         k = k.flatten(-2, -1)
+         return q, k
+
+     def forward(self, x, seq_id):
+         qkv_BLD3 = self.layernorm_qkv(x)
+         query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
+         query_BLD, key_BLD = (
+             self.q_ln(query_BLD).to(query_BLD.dtype),
+             self.k_ln(key_BLD).to(query_BLD.dtype),
+         )
+         query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)
+
+         n_heads = self.n_heads
+         reshaper = functools.partial(
+             einops.rearrange, pattern="b s (h d) -> b h s d", h=n_heads
+         )
+
+         query_BHLD, key_BHLD, value_BHLD = map(
+             reshaper, (query_BLD, key_BLD, value_BLD)
+         )
+
+         if seq_id is not None:
+             # Where True, enable participation in attention.
+             mask_BLL = seq_id.unsqueeze(-1) == seq_id.unsqueeze(-2)
+             mask_BHLL = mask_BLL.unsqueeze(1)
+
+             context_BHLD = F.scaled_dot_product_attention(
+                 query_BHLD, key_BHLD, value_BHLD, mask_BHLL
+             )
+         else:
+             # Shortcut: if we don't use attention biases then torch
+             # will auto-select flash attention as the implementation
+             context_BHLD = F.scaled_dot_product_attention(
+                 query_BHLD, key_BHLD, value_BHLD
+             )
+         context_BLD = einops.rearrange(context_BHLD, "b h s d -> b s (h d)")
+         return self.out_proj(context_BLD)
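A minimal shape check for the block above (values arbitrary; assumes PyTorch >= 2.1 for `nn.LayerNorm(..., bias=...)`, which the README's PyTorch 2.5.0 satisfies):

```python
import torch
from esm.layers.attention import MultiHeadAttention

attn = MultiHeadAttention(d_model=64, n_heads=4)
x = torch.randn(2, 10, 64)                     # (batch, length, d_model)
seq_id = torch.zeros(2, 10, dtype=torch.long)  # one sequence per row, no packing
out = attn(x, seq_id)                          # -> (2, 10, 64)
```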
Dyna-1/esm/layers/blocks.py ADDED
@@ -0,0 +1,153 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from esm.layers.attention import MultiHeadAttention
+ from esm.layers.geom_attention import (
+     GeometricReasoningOriginalImpl,
+ )
+ from esm.utils.structure.affine3d import Affine3D
+
+
+ def swiglu_correction_fn(expansion_ratio: float, d_model: int) -> int:
+     # set hidden dimension to nearest multiple of 256 after expansion ratio
+     return int(((expansion_ratio * d_model) + 255) // 256 * 256)
+
+
+ class SwiGLU(nn.Module):
+     """
+     SwiGLU activation function as an nn.Module, allowing it to be used within nn.Sequential.
+     This module splits the input tensor along the last dimension and applies the SiLU (Swish)
+     activation function to the first half, then multiplies it by the second half.
+     """
+
+     def __init__(self):
+         super(SwiGLU, self).__init__()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x1, x2 = x.chunk(2, dim=-1)
+         return F.silu(x1) * x2
+
+
+ def swiglu_ln_ffn(d_model: int, expansion_ratio: float, bias: bool):
+     return nn.Sequential(
+         nn.LayerNorm(d_model),
+         nn.Linear(
+             d_model, swiglu_correction_fn(expansion_ratio, d_model) * 2, bias=bias
+         ),
+         SwiGLU(),
+         nn.Linear(swiglu_correction_fn(expansion_ratio, d_model), d_model, bias=bias),
+     )
+
+
+ def gelu_ln_ffn(d_model: int, expansion_ratio: float, bias: bool):
+     hidden_dim = int(expansion_ratio * d_model)
+     return nn.Sequential(
+         nn.LayerNorm(d_model),
+         nn.Linear(d_model, hidden_dim, bias=bias),
+         nn.GELU(),
+         nn.Linear(hidden_dim, d_model, bias=bias),
+     )
+
+
+ class UnifiedTransformerBlock(nn.Module):
+     """
+     A unified transformer block that can optionally incorporate geometric attention.
+
+     This class defines a transformer block that can be configured to use geometric attention
+     alongside the standard multi-head attention mechanism. It is designed to be a flexible
+     component of transformer-based models, allowing for the integration of geometric reasoning.
+
+     Parameters
+     ----------
+     d_model : int
+         The dimensionality of the input and output features of the transformer block.
+     n_heads : int
+         The number of attention heads in the multi-head attention mechanism.
+     n_layers : int
+         The number of layers in the transformer block.
+     use_geom_attn : bool, optional
+         Whether to use geometric attention in addition to the standard multi-head attention. Defaults to False.
+     v_heads : int, optional
+         The number of heads to use for the geometric attention mechanism, if enabled. Must be specified if `use_geom_attn` is True.
+     """
+
+     def __init__(
+         self,
+         d_model: int,
+         n_heads: int,
+         use_geom_attn: bool = False,
+         use_plain_attn: bool = True,
+         v_heads: int | None = None,
+         bias: bool = False,
+         expansion_ratio: float = 4.0,
+         residue_scaling_factor: float = 1,
+         mask_and_zero_frameless: bool = False,
+         qk_layernorm: bool = True,
+         ffn_type: str = "swiglu",  # swiglu | gelu
+     ):
+         super().__init__()
+         self.use_plain_attn = use_plain_attn
+         if self.use_plain_attn:
+             self.attn = MultiHeadAttention(
+                 d_model, n_heads, bias, qk_layernorm=qk_layernorm
+             )
+         self.use_geom_attn = use_geom_attn
+         if self.use_geom_attn:
+             if v_heads is None:
+                 raise ValueError("v_heads must be specified when use_geom_attn is True")
+             self.geom_attn = GeometricReasoningOriginalImpl(
+                 c_s=d_model,
+                 v_heads=v_heads,
+                 bias=bias,
+                 mask_and_zero_frameless=mask_and_zero_frameless,
+             )
+         if ffn_type == "swiglu":
+             self.ffn = swiglu_ln_ffn(d_model, expansion_ratio, bias)
+         elif ffn_type == "gelu":
+             self.ffn = gelu_ln_ffn(d_model, expansion_ratio, bias)
+         else:
+             raise ValueError(f"Unknown ffn_type: {ffn_type}")
+         self.scaling_factor = residue_scaling_factor
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         sequence_id: torch.Tensor,
+         frames: Affine3D,
+         frames_mask: torch.Tensor,
+         chain_id: torch.Tensor,
+     ) -> torch.Tensor:
+         """
+         Forward pass for the UnifiedTransformerBlock.
+
+         Parameters
+         ----------
+         x : torch.Tensor[float]
+             Input tensor to the transformer block, typically the output from the previous layer.
+         sequence_id : torch.Tensor[int]
+             Tensor containing sequence IDs for each element in the batch, used for attention masking.
+         frames : Affine3D
+             Affine3D containing geometric frame information for geometric attention.
+         frames_mask : torch.Tensor[bool]
+             Boolean mask tensor indicating valid frames for geometric attention.
+         chain_id : torch.Tensor[int]
+             Tensor containing chain IDs for each element, used for attention masking in geometric attention.
+
+         Returns
+         -------
+         torch.Tensor[float]
+             The output tensor after applying the transformer block operations.
+         """
+         if self.use_plain_attn:
+             r1 = self.attn(x, sequence_id)
+             x = x + r1 / self.scaling_factor
+
+         if self.use_geom_attn:
+             r2 = self.geom_attn(x, frames, frames_mask, sequence_id, chain_id)
+             x = x + r2 / self.scaling_factor
+
+         r3 = self.ffn(x) / self.scaling_factor
+         x = x + r3
+
+         return x
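A hedged sketch of the block in sequence-only mode (with `use_geom_attn=False`, the frame arguments are never touched, so `None` placeholders are safe):

```python
import torch
from esm.layers.blocks import UnifiedTransformerBlock

block = UnifiedTransformerBlock(d_model=64, n_heads=4)  # plain attention + SwiGLU FFN
x = torch.randn(2, 10, 64)
seq_id = torch.zeros(2, 10, dtype=torch.long)
out = block(x, seq_id, frames=None, frames_mask=None, chain_id=None)  # (2, 10, 64)
```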
Dyna-1/esm/layers/codebook.py ADDED
@@ -0,0 +1,88 @@
+ import numpy as np
+ import torch
+ import torch.distributed as dist
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class EMACodebook(nn.Module):
+     def __init__(
+         self,
+         n_codes,
+         embedding_dim,
+         no_random_restart=True,
+         restart_thres=1.0,
+         ema_decay=0.99,
+     ):
+         super().__init__()
+         self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim))
+         self.register_buffer("N", torch.zeros(n_codes))
+         self.register_buffer("z_avg", self.embeddings.data.clone())
+
+         self.n_codes = n_codes
+         self.embedding_dim = embedding_dim
+         self._need_init = True
+         self.no_random_restart = no_random_restart
+         self.restart_thres = restart_thres
+         self.freeze_codebook = False
+         self.ema_decay = ema_decay
+
+     def reset_parameters(self):
+         # For meta init
+         pass
+
+     def _tile(self, x):
+         d, ew = x.shape
+         if d < self.n_codes:
+             n_repeats = (self.n_codes + d - 1) // d
+             std = 0.01 / np.sqrt(ew)
+             x = x.repeat(n_repeats, 1)
+             x = x + torch.randn_like(x) * std
+         return x
+
+     def _init_embeddings(self, z):
+         # z: [b, t, c]
+         self._need_init = False
+         flat_inputs = z.view(-1, self.embedding_dim)
+         y = self._tile(flat_inputs)
+
+         y.shape[0]
+         _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]
+         if dist.is_initialized():
+             dist.broadcast(_k_rand, 0)
+         self.embeddings.data.copy_(_k_rand)
+         self.z_avg.data.copy_(_k_rand)
+         self.N.data.copy_(torch.ones(self.n_codes))
+
+     def forward(self, z):
+         # z: [b, t, c]
+         if self._need_init and self.training and not self.freeze_codebook:
+             self._init_embeddings(z)
+         # z is of shape [batch_size, sequence length, channels]
+         flat_inputs = z.view(-1, self.embedding_dim)
+         distances = (
+             (flat_inputs**2).sum(dim=1, keepdim=True)
+             - 2 * flat_inputs @ self.embeddings.t()
+             + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True)
+         )  # [bt, c]
+
+         encoding_indices = torch.argmin(distances, dim=1)
+         encoding_indices = encoding_indices.view(*z.shape[:2])  # [b, t, ncode]
+
+         embeddings = F.embedding(encoding_indices, self.embeddings)  # [b, t, c]
+
+         commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach())
+
+         # EMA codebook update
+         if self.training and not self.freeze_codebook:
+             assert False, "Not implemented"
+         embeddings_st = (embeddings - z).detach() + z
+
+         return embeddings_st, encoding_indices, commitment_loss
+
+     def dictionary_lookup(self, encodings):
+         embeddings = F.embedding(encodings, self.embeddings)
+         return embeddings
+
+     def soft_codebook_lookup(self, weights: torch.Tensor) -> torch.Tensor:
+         return weights @ self.embeddings
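In eval mode the unimplemented EMA-update branch above is never reached, so a nearest-code lookup can be sketched as follows (shapes arbitrary):

```python
import torch
from esm.layers.codebook import EMACodebook

cb = EMACodebook(n_codes=16, embedding_dim=8).eval()
z = torch.randn(2, 5, 8)                 # (batch, time, channels)
quantized, indices, commit_loss = cb(z)  # indices: (2, 5) code ids
```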
Dyna-1/esm/layers/ffn.py ADDED
@@ -0,0 +1,29 @@
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch import Tensor
+
+ # NOT CURRENTLY USED
+
+
+ class SwiGLU(nn.Module):
+     def __init__(self) -> None:
+         super().__init__()
+
+     def forward(self, x: Tensor) -> Tensor:
+         x1, x2 = x.chunk(2, dim=-1)
+         hidden = F.silu(x1) * x2
+         return hidden
+
+
+ class FFN(nn.Module):
+     def __init__(self, in_proj, activation, out_proj) -> None:
+         super().__init__()
+         self.in_proj = in_proj
+         self.activation = activation
+         self.out_proj = out_proj
+
+     def forward(self, x: Tensor) -> Tensor:
+         x = self.in_proj(x)
+         x = self.activation(x)
+         x = self.out_proj(x)
+         return x
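Though marked as unused, the module composes as follows; note that SwiGLU halves the width, so the in-projection must be twice the hidden size (a toy sketch with arbitrary dimensions):

```python
import torch
import torch.nn as nn
from esm.layers.ffn import FFN, SwiGLU

ffn = FFN(nn.Linear(64, 256), SwiGLU(), nn.Linear(128, 64))
out = ffn(torch.randn(2, 10, 64))  # SwiGLU: 256 -> 128, then project back to 64
```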
Dyna-1/esm/layers/geom_attention.py ADDED
@@ -0,0 +1,149 @@
1
+ from math import sqrt
2
+
3
+ import torch
4
+ from einops import rearrange
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+
9
+ class GeometricReasoningOriginalImpl(nn.Module):
10
+ def __init__(
11
+ self,
12
+ c_s: int,
13
+ v_heads: int,
14
+ num_vector_messages: int = 1,
15
+ mask_and_zero_frameless: bool = True,
16
+ divide_residual_by_depth: bool = False,
17
+ bias: bool = False,
18
+ ):
19
+ """Approximate implementation:
20
+
21
+ ATTN(A, v) := (softmax_j A_ij) v_j
22
+ make_rot_vectors(x) := R(i->g) Linear(x).reshape(..., 3)
23
+ make_vectors(x) := T(i->g) Linear(x).reshape(..., 3)
24
+
25
+ v <- make_rot_vectors(x)
26
+ q_dir, k_dir <- make_rot_vectors(x)
27
+ q_dist, k_dist <- make_vectors(x)
28
+
29
+ A_ij <- dot(q_dir_i, k_dir_j) -||q_dist_i - k_dist_j||^2
30
+ x <- x + Linear(T(g->i) ATTN(A, v))
31
+ """
32
+ super().__init__()
33
+ self.c_s = c_s
34
+ self.v_heads = v_heads
35
+ self.num_vector_messages = num_vector_messages
36
+ self.mask_and_zero_frameless = mask_and_zero_frameless
37
+
38
+ self.s_norm = nn.LayerNorm(c_s, bias=bias)
39
+ dim_proj = (
40
+ 4 * self.v_heads * 3 + self.v_heads * 3 * self.num_vector_messages
41
+ ) # 2 x (q, k) * number of heads * (x, y, z) + number of heads * number of vector messages * (x, y, z)
42
+ self.proj = nn.Linear(c_s, dim_proj, bias=bias)
43
+ channels_out = self.v_heads * 3 * self.num_vector_messages
44
+ self.out_proj = nn.Linear(channels_out, c_s, bias=bias)
45
+
46
+ # The basic idea is for some attention heads to pay more or less attention to rotation versus distance,
47
+ # as well as to control the sharpness of the softmax (i.e., should this head only attend to those residues
48
+ # very nearby or should there be shallower dropoff in attention weight?)
49
+ self.distance_scale_per_head = nn.Parameter(torch.zeros((self.v_heads)))
50
+ self.rotation_scale_per_head = nn.Parameter(torch.zeros((self.v_heads)))
51
+
52
+ def forward(self, s, affine, affine_mask, sequence_id, chain_id):
53
+ if sequence_id is None:
54
+ sequence_id = torch.zeros_like(s[..., 0], dtype=torch.int64)
55
+ attn_bias = sequence_id.unsqueeze(-1) == sequence_id.unsqueeze(-2)
56
+ attn_bias = attn_bias.unsqueeze(1).float()
57
+ attn_bias = attn_bias.masked_fill(
58
+ ~affine_mask[:, None, None, :], torch.finfo(attn_bias.dtype).min
59
+ )
60
+ chain_id_mask = chain_id.unsqueeze(1) != chain_id.unsqueeze(2)
61
+ attn_bias = attn_bias.masked_fill(
62
+ chain_id_mask.unsqueeze(1), torch.finfo(s.dtype).min
63
+ )
64
+
65
+ ns = self.s_norm(s)
66
+ vec_rot, vec_dist = self.proj(ns).split(
67
+ [
68
+ self.v_heads * 2 * 3 + self.v_heads * 3 * self.num_vector_messages,
69
+ self.v_heads * 2 * 3,
70
+ ],
71
+ dim=-1,
72
+ )
73
+
74
+ # Rotate the queries and keys for the rotation term. We also rotate the values.
75
+ # NOTE(zeming, thayes): Values are only rotated, not translated. We may wish to change
76
+ # this in the future.
77
+ query_rot, key_rot, value = (
78
+ affine.rot[..., None]
79
+ .apply(rearrange(vec_rot, "... (h c) -> ... h c", c=3))
80
+ .split(
81
+ [self.v_heads, self.v_heads, self.v_heads * self.num_vector_messages],
82
+ dim=-2,
83
+ )
84
+ )
85
+
86
+ # Rotate and translate the queries and keys for the distance term
87
+ # NOTE(thayes): a simple speedup would be to apply all rotations together, then
88
+ # separately apply the translations.
89
+ query_dist, key_dist = (
90
+ affine[..., None]
91
+ .apply(rearrange(vec_dist, "... (h c) -> ... h c", c=3))
92
+ .chunk(2, dim=-2)
93
+ )
94
+
95
+ query_dist = rearrange(query_dist, "b s h d -> b h s 1 d")
96
+ key_dist = rearrange(key_dist, "b s h d -> b h 1 s d")
97
+ query_rot = rearrange(query_rot, "b s h d -> b h s d")
98
+ key_rot = rearrange(key_rot, "b s h d -> b h d s")
99
+ value = rearrange(
100
+ value, "b s (h m) d -> b h s (m d)", m=self.num_vector_messages
101
+ )
102
+
103
+ distance_term = (query_dist - key_dist).norm(dim=-1) / sqrt(3)
104
+ rotation_term = query_rot.matmul(key_rot) / sqrt(3)
105
+ distance_term_weight = rearrange(
106
+ F.softplus(self.distance_scale_per_head), "h -> h 1 1"
107
+ )
108
+ rotation_term_weight = rearrange(
109
+ F.softplus(self.rotation_scale_per_head), "h -> h 1 1"
110
+ )
111
+
112
+ attn_weight = (
113
+ rotation_term * rotation_term_weight - distance_term * distance_term_weight
114
+ )
115
+
116
+ if attn_bias is not None:
117
+ # we can re-use the attention bias from the transformer layers
118
+ # NOTE(thayes): This attention bias is expected to handle two things:
119
+ # 1. Masking attention on padding tokens
120
+ # 2. Masking cross sequence attention in the case of bin packing
121
+ s_q = attn_weight.size(2)
122
+ s_k = attn_weight.size(3)
123
+ _s_q = max(0, attn_bias.size(2) - s_q)
124
+ _s_k = max(0, attn_bias.size(3) - s_k)
125
+ attn_bias = attn_bias[:, :, _s_q:, _s_k:]
126
+ attn_weight = attn_weight + attn_bias
127
+
128
+ attn_weight = torch.softmax(attn_weight, dim=-1)
129
+
130
+ attn_out = attn_weight.matmul(value)
131
+
132
+ attn_out = (
133
+ affine.rot[..., None]
134
+ .invert()
135
+ .apply(
136
+ rearrange(
137
+ attn_out, "b h s (m d) -> b s (h m) d", m=self.num_vector_messages
138
+ )
139
+ )
140
+ )
141
+
142
+ attn_out = rearrange(
143
+ attn_out, "b s (h m) d -> b s (h m d)", m=self.num_vector_messages
144
+ )
145
+ if self.mask_and_zero_frameless:
146
+ attn_out = attn_out.masked_fill(~affine_mask[..., None], 0.0)
147
+ s = self.out_proj(attn_out)
148
+
149
+ return s
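The geometric attention score above combines a frame-aligned dot-product (rotation) term with a pairwise distance term, each weighted by a per-head softplus scale. A toy, self-contained recreation of that arithmetic (shapes and values are illustrative assumptions, not taken from this file):

import torch
import torch.nn.functional as F
from math import sqrt

h, s = 2, 4                                   # heads, sequence length (toy sizes)
q_rot = torch.randn(1, h, s, 3)               # rotated query vectors
k_rot = torch.randn(1, h, 3, s)               # rotated key vectors (transposed)
q_dist = torch.randn(1, h, s, 1, 3)           # rotated + translated queries
k_dist = torch.randn(1, h, 1, s, 3)           # rotated + translated keys
rot_w = F.softplus(torch.zeros(h)).view(h, 1, 1)    # zeros init, matching the Parameters above
dist_w = F.softplus(torch.zeros(h)).view(h, 1, 1)
rotation_term = q_rot.matmul(k_rot) / sqrt(3)
distance_term = (q_dist - k_dist).norm(dim=-1) / sqrt(3)
attn_weight = rotation_term * rot_w - distance_term * dist_w   # (1, h, s, s)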
Dyna-1/esm/layers/regression_head.py ADDED
@@ -0,0 +1,22 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ def RegressionHead(
5
+ d_model: int, output_dim: int, hidden_dim: int | None = None
6
+ ) -> nn.Module:
7
+ """Single-hidden layer MLP for supervised output.
8
+
9
+ Args:
10
+ d_model: input dimension
11
+ output_dim: dimensionality of the output.
12
+ hidden_dim: optional dimension of hidden layer, defaults to d_model.
13
+ Returns:
14
+ output MLP module.
15
+ """
16
+ hidden_dim = hidden_dim if hidden_dim is not None else d_model
17
+ return nn.Sequential(
18
+ nn.Linear(d_model, hidden_dim),
19
+ nn.GELU(),
20
+ nn.LayerNorm(hidden_dim),
21
+ nn.Linear(hidden_dim, output_dim),
22
+ )
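A minimal usage sketch of the head defined above; the dimensions are illustrative assumptions, not values from this repo:

import torch

head = RegressionHead(d_model=1536, output_dim=64)
x = torch.randn(2, 100, 1536)    # (batch, seq_len, d_model)
logits = head(x)                 # (2, 100, 64); each Linear acts on the last dim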
Dyna-1/esm/layers/rotary.py ADDED
@@ -0,0 +1,221 @@
1
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
4
+ # and OPT implementations in this library. It has been modified from its
5
+ # original forms to accommodate minor architectural differences compared
6
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ # NOTE: this implementation is from LLaMA 2:
20
+ # https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/08639a72e17836184096ae6a7e2766f2a34c3e36/modeling_flash_llama.py#L114
21
+ # Flash attention rotary implementation can be installed like so: `pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary`
22
+
23
+ from typing import Tuple
24
+
25
+ import torch
26
+ from einops import rearrange, repeat
27
+
28
+
29
+ def rotate_half(x, interleaved=False):
30
+ if not interleaved:
31
+ x1, x2 = x.chunk(2, dim=-1)
32
+ return torch.cat((-x2, x1), dim=-1)
33
+ else:
34
+ x1, x2 = x[..., ::2], x[..., 1::2]
35
+ return rearrange(
36
+ torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
37
+ )
38
+
39
+
40
+ def apply_rotary_emb_torch(x, cos, sin, interleaved=False, _inplace=False):
41
+ """
42
+ x: (batch_size, seqlen, nheads, headdim)
43
+ cos, sin: (seqlen, rotary_dim / 2)
44
+ """
45
+ ro_dim = cos.shape[-1] * 2
46
+ assert ro_dim <= x.shape[-1]
47
+ seqlen = x.size(1)
48
+ cos = cos[:seqlen]
49
+ sin = sin[:seqlen]
50
+ cos = repeat(cos, "s d -> s 1 (2 d)")
51
+ sin = repeat(sin, "s d -> s 1 (2 d)")
52
+ return torch.cat(
53
+ [
54
+ x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
55
+ x[..., ro_dim:],
56
+ ],
57
+ dim=-1,
58
+ )
59
+
60
+
61
+ class RotaryEmbedding(torch.nn.Module):
62
+ """
63
+ The rotary position embeddings from RoFormer_ (Su et al.).
64
+ A crucial insight from the method is that the query and keys are
65
+ transformed by rotation matrices which depend on the relative positions.
66
+ Other implementations are available in the Rotary Transformer repo_ and in
67
+ GPT-NeoX_; GPT-NeoX was an inspiration.
68
+ .. _RoFormer: https://arxiv.org/abs/2104.09864
69
+ .. _repo: https://github.com/ZhuiyiTechnology/roformer
70
+ .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
71
+ If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
72
+ A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
73
+ Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
74
+ """
75
+
76
+ def __init__(
77
+ self,
78
+ dim: int,
79
+ base=10000.0,
80
+ interleaved=False,
81
+ scale_base=None,
82
+ scaling_factor=1.0,
83
+ pos_idx_in_fp32=True,
84
+ device=None,
85
+ ):
86
+ """
87
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
88
+ of 1st half and 2nd half (GPT-NeoX style).
89
+ pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
90
+ otherwise they might be in lower precision.
91
+ This option was added because previously (before 2023-07-02), when we construct
92
+ the position indices, we use the dtype of self.inv_freq. In most cases this would
93
+ be fp32, but if the model is trained in pure bf16 (not mixed precision), then
94
+ self.inv_freq would be bf16, and the position indices are also in bf16.
95
+ Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
96
+ embeddings for some positions will coincide.
97
+ To maintain compatibility with models previously trained in pure bf16,
98
+ we add this option.
99
+ scaling_factor: RotaryEmbedding extended with linear scaling.
100
+ """
101
+ super().__init__()
102
+ self.dim = dim
103
+ self.base = float(base)
104
+ self.pos_idx_in_fp32 = pos_idx_in_fp32
105
+ # Generate and save the inverse frequency buffer (non-trainable)
106
+ self.interleaved = interleaved
107
+ self.scale_base = scale_base
108
+ self.scaling_factor = scaling_factor
109
+ self.device = device
110
+
111
+ self._seq_len_cached = 0
112
+ self._cos_cached = None
113
+ self._sin_cached = None
114
+ self._cos_k_cached = None
115
+ self._sin_k_cached = None
116
+ self.reset_parameters()
117
+
118
+ def reset_parameters(self):
119
+ inv_freq = self._compute_inv_freq(self.device)
120
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
121
+ arange = torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32)
122
+ scale = (
123
+ (arange + 0.4 * self.dim) / (1.4 * self.dim)
124
+ if self.scale_base is not None
125
+ else None
126
+ )
127
+ self.register_buffer("scale", scale)
128
+
129
+ def _compute_inv_freq(self, device=None):
130
+ return 1 / (
131
+ self.base
132
+ ** (
133
+ torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
134
+ / self.dim
135
+ )
136
+ )
137
+
138
+ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
139
+ # Reset the tables if the sequence length has changed,
140
+ # if we're on a new device (possibly due to tracing for instance),
141
+ # or if we're switching from inference mode to training
142
+ if (
143
+ seqlen > self._seq_len_cached
144
+ or self._cos_cached is None
145
+ or self._cos_cached.device != device
146
+ or self._cos_cached.dtype != dtype
147
+ or (self.training and self._cos_cached.is_inference())
148
+ ):
149
+ self._seq_len_cached = seqlen
150
+ # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
151
+ # And the output of arange can be quite large, so bf16 would lose a lot of precision.
152
+ # However, for compatibility reasons, we add an option to use the dtype of self.inv_freq.
153
+ if self.pos_idx_in_fp32:
154
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
155
+ t /= self.scaling_factor
156
+ # We want fp32 here as well since inv_freq will be multiplied with t, and the output
157
+ # will be large. Having it in bf16 will lose a lot of precision and cause the
158
+ # cos & sin output to change significantly.
159
+ # We want to recompute self.inv_freq if it was not loaded in fp32
160
+ if self.inv_freq.dtype != torch.float32:
161
+ inv_freq = self.inv_freq.to(torch.float32)
162
+ else:
163
+ inv_freq = self.inv_freq
164
+ else:
165
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
166
+ t /= self.scaling_factor
167
+ inv_freq = self.inv_freq
168
+ # Don't do einsum, it converts fp32 to fp16 under AMP
169
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
170
+ freqs = torch.outer(t, inv_freq)
171
+
172
+ if self.scale is None:
173
+ self._cos_cached = torch.cos(freqs).to(dtype)
174
+ self._sin_cached = torch.sin(freqs).to(dtype)
175
+ else:
176
+ power = (
177
+ torch.arange(
178
+ seqlen, dtype=self.scale.dtype, device=self.scale.device
179
+ )
180
+ - seqlen // 2
181
+ ) / self.scale_base
182
+ scale = self.scale.to(device=power.device) ** power.unsqueeze(-1)
183
+ # We want the multiplication by scale to happen in fp32
184
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
185
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
186
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
187
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
188
+
189
+ def forward(
190
+ self, q: torch.Tensor, k: torch.Tensor, seqlen_offset: int = 0
191
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
192
+ """
193
+ q: (batch, seqlen, nheads, headdim)
194
+ k: (batch, seqlen, nheads, headdim)
195
+ seqlen_offset: can be used in generation where the qkv being passed in is only the last
196
+ token in the batch.
197
+ """
198
+ self._update_cos_sin_cache(
199
+ q.shape[1] + seqlen_offset, device=q.device, dtype=q.dtype
200
+ )
201
+ assert self._cos_cached is not None
202
+ assert self._sin_cached is not None
203
+ if self.scale is None:
204
+ return (
205
+ apply_rotary_emb_torch(
206
+ q,
207
+ self._cos_cached[seqlen_offset:],
208
+ self._sin_cached[seqlen_offset:],
209
+ self.interleaved,
210
+ True, # inplace=True
211
+ ),
212
+ apply_rotary_emb_torch(
213
+ k,
214
+ self._cos_cached[seqlen_offset:],
215
+ self._sin_cached[seqlen_offset:],
216
+ self.interleaved,
217
+ True, # inplace=True
218
+ ),
219
+ ) # type: ignore
220
+ else:
221
+ assert False
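A minimal sketch of applying the module above to query/key tensors, assuming the default (non-XPos, non-interleaved) configuration; shapes are illustrative:

import torch

rope = RotaryEmbedding(dim=32)           # rotary dim equals head dim here
q = torch.randn(2, 16, 4, 32)            # (batch, seqlen, nheads, headdim)
k = torch.randn(2, 16, 4, 32)
q_rot, k_rot = rope(q, k)                # same shapes; positions encoded as rotations
assert q_rot.shape == q.shape and k_rot.shape == k.shape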
Dyna-1/esm/layers/structure_proj.py ADDED
@@ -0,0 +1,66 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from esm.utils.constants.physics import BB_COORDINATES
5
+ from esm.utils.structure.affine3d import (
6
+ Affine3D,
7
+ RotationMatrix,
8
+ )
9
+
10
+
11
+ class Dim6RotStructureHead(nn.Module):
12
+ # Normally, AF2 uses quaternions to specify rotations. There's some evidence that
13
+ # other representations are more well behaved - the best one according to
14
+ # https://openaccess.thecvf.com/content_CVPR_2019/papers/Zhou_On_the_Continuity_of_Rotation_Representations_in_Neural_Networks_CVPR_2019_paper.pdf
15
+ # is using graham schmidt on 2 vectors, which is implemented here.
16
+ def __init__(
17
+ self,
18
+ input_dim: int,
19
+ trans_scale_factor: float = 10,
20
+ norm_type: str = "layernorm",
21
+ activation_fn: str = "esm_gelu",
22
+ predict_torsion_angles: bool = True,
23
+ ):
24
+ super().__init__()
25
+ self.ffn1 = nn.Linear(input_dim, input_dim)
26
+ self.activation_fn = nn.GELU()
27
+ self.norm = nn.LayerNorm(input_dim)
28
+ self.proj = nn.Linear(input_dim, 9 + 7 * 2)
29
+ self.trans_scale_factor = trans_scale_factor
30
+ self.predict_torsion_angles = predict_torsion_angles
31
+ self.bb_local_coords = torch.tensor(BB_COORDINATES).float()
32
+
33
+ def forward(self, x, affine, affine_mask, **kwargs):
34
+ if affine is None:
35
+ rigids = Affine3D.identity(
36
+ x.shape[:-1],
37
+ dtype=x.dtype,
38
+ device=x.device,
39
+ requires_grad=self.training,
40
+ rotation_type=RotationMatrix,
41
+ )
42
+ else:
43
+ rigids = affine
44
+
45
+ # [*, N]
46
+ x = self.ffn1(x)
47
+ x = self.activation_fn(x)
48
+ x = self.norm(x)
49
+ trans, x, y, angles = self.proj(x).split([3, 3, 3, 7 * 2], dim=-1)
50
+ trans = trans * self.trans_scale_factor
51
+ x = x / (x.norm(dim=-1, keepdim=True) + 1e-5)
52
+ y = y / (y.norm(dim=-1, keepdim=True) + 1e-5)
53
+ update = Affine3D.from_graham_schmidt(x + trans, trans, y + trans)
54
+ rigids = rigids.compose(update.mask(affine_mask))
55
+ affine = rigids.tensor
56
+
57
+ # We approximate the positions of the backbone atoms in the global frame by applying the rigid
58
+ # transformation to the mean of the backbone atoms in the local frame.
59
+ all_bb_coords_local = (
60
+ self.bb_local_coords[None, None, :, :]
61
+ .expand(*x.shape[:-1], 3, 3)
62
+ .to(x.device)
63
+ )
64
+ pred_xyz = rigids[..., None].apply(all_bb_coords_local)
65
+
66
+ return affine, pred_xyz
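A self-contained sketch of the Gram-Schmidt construction referenced above. The library's own version lives in Affine3D.from_graham_schmidt; this standalone function is illustrative only:

import torch
import torch.nn.functional as F

def gram_schmidt_rotation(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Build an orthonormal, right-handed frame from two 3-vectors.
    e1 = F.normalize(a, dim=-1)
    u2 = b - (e1 * b).sum(-1, keepdim=True) * e1   # remove the component along e1
    e2 = F.normalize(u2, dim=-1)
    e3 = torch.cross(e1, e2, dim=-1)
    return torch.stack([e1, e2, e3], dim=-1)       # columns form the rotation

R = gram_schmidt_rotation(torch.randn(3), torch.randn(3))
assert torch.allclose(R @ R.T, torch.eye(3), atol=1e-5)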
Dyna-1/esm/layers/transformer_stack.py ADDED
@@ -0,0 +1,93 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from esm.layers.blocks import UnifiedTransformerBlock
7
+ from esm.utils.structure.affine3d import Affine3D
8
+
9
+
10
+ class TransformerStack(nn.Module):
11
+ """
12
+ A stack of transformer blocks used in the ESM-3 model. Each block is a UnifiedTransformerBlock,
13
+ which can either be geometric attention or standard multi-head attention.
14
+
15
+ Args:
16
+ d_model (int): The dimensionality of the input and output feature vectors.
17
+ n_heads (int): The number of attention heads.
18
+ v_heads (int | None): The number of vector heads used by the geometric attention blocks (None disables them).
19
+ n_layers (int): The number of transformer blocks in the stack.
20
+ n_layers_geom (int, optional): The number of transformer blocks that use geometric attention.
21
+ scale_residue (bool, optional): Whether to scale the residue connections in each transformer block.
22
+ mask_and_zero_frameless (bool, optional): Whether to mask and zero frameless positions in the input.
23
+ Only applies in the geometric attention blocks, which is conditioned on the structure
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ d_model: int,
29
+ n_heads: int,
30
+ v_heads: int | None,
31
+ n_layers: int,
32
+ n_layers_geom: int = 1,
33
+ scale_residue: bool = True,
34
+ mask_and_zero_frameless: bool = False,
35
+ bias: bool = False,
36
+ qk_layernorm: bool = True,
37
+ ffn_type: str = "swiglu", # swiglu | gelu
38
+ expansion_ratio: float = 8 / 3,
39
+ ):
40
+ super().__init__()
41
+ self.blocks = nn.ModuleList(
42
+ [
43
+ UnifiedTransformerBlock(
44
+ d_model,
45
+ n_heads,
46
+ v_heads=v_heads,
47
+ use_geom_attn=i < n_layers_geom,
48
+ residue_scaling_factor=(
49
+ math.sqrt(n_layers / 36) if scale_residue else 1.0
50
+ ),
51
+ expansion_ratio=expansion_ratio,
52
+ mask_and_zero_frameless=mask_and_zero_frameless,
53
+ bias=bias,
54
+ qk_layernorm=qk_layernorm,
55
+ ffn_type=ffn_type,
56
+ )
57
+ for i in range(n_layers)
58
+ ]
59
+ )
60
+ self.norm = nn.LayerNorm(d_model, bias=False)
61
+
62
+ def forward(
63
+ self,
64
+ x: torch.Tensor,
65
+ sequence_id: torch.Tensor | None = None,
66
+ affine: Affine3D | None = None,
67
+ affine_mask: torch.Tensor | None = None,
68
+ chain_id: torch.Tensor | None = None,
69
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
70
+ """
71
+ Forward pass of the TransformerStack.
72
+
73
+ Args:
74
+ x (torch.Tensor): The input tensor of shape (batch_size, sequence_length, d_model).
75
+ sequence_id (torch.Tensor): The sequence ID tensor of shape (batch_size, sequence_length).
76
+ affine (Affine3D | None): The affine transformation tensor or None.
77
+ affine_mask (torch.Tensor | None): The affine mask tensor or None.
78
+ chain_id (torch.Tensor): The protein chain tensor of shape (batch_size, sequence_length).
79
+ Only used in geometric attention.
80
+
81
+ Returns:
82
+ post_norm: The output tensor of shape (batch_size, sequence_length, d_model).
83
+ pre_norm: The embedding of shape (batch_size, sequence_length, d_model).
+ hiddens: Stacked per-block outputs of shape (n_layers, batch_size, sequence_length, d_model).
84
+ """
85
+ *batch_dims, _ = x.shape
86
+ if chain_id is None:
87
+ chain_id = torch.ones(size=batch_dims, dtype=torch.int64, device=x.device)
88
+ hiddens = []
89
+ for block in self.blocks:
90
+ x = block(x, sequence_id, affine, affine_mask, chain_id)
91
+ hiddens.append(x)
92
+ hiddens = torch.stack(hiddens, dim=0)
93
+ return self.norm(x), x, hiddens
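A minimal sequence-only usage sketch (no geometric layers), assuming UnifiedTransformerBlock accepts affine=None when geometric attention is off; sizes are illustrative:

import torch

stack = TransformerStack(d_model=256, n_heads=4, v_heads=None, n_layers=2, n_layers_geom=0)
x = torch.randn(2, 10, 256)
post_norm, pre_norm, hiddens = stack(x)
assert hiddens.shape == (2, 2, 10, 256)   # (n_layers, batch, seq_len, d_model)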
Dyna-1/esm/models/esm3.py ADDED
@@ -0,0 +1,606 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ from functools import partial
5
+ from typing import Callable
6
+
7
+ import attr
8
+ import einops
9
+ import torch
10
+ import torch.nn as nn
11
+ from attr import dataclass
12
+
13
+ from esm.layers.regression_head import RegressionHead
14
+ from esm.layers.transformer_stack import TransformerStack
15
+ from esm.models.function_decoder import FunctionTokenDecoder
16
+ from esm.models.vqvae import (
17
+ StructureTokenDecoder,
18
+ StructureTokenEncoder,
19
+ )
20
+ from esm.sdk.api import (
21
+ ESM3InferenceClient,
22
+ ESMProtein,
23
+ ESMProteinTensor,
24
+ ForwardAndSampleOutput,
25
+ ForwardTrackData,
26
+ GenerationConfig,
27
+ LogitsConfig,
28
+ LogitsOutput,
29
+ ProteinType,
30
+ SamplingConfig,
31
+ )
32
+ from esm.tokenization import TokenizerCollectionProtocol
33
+ from esm.utils import encoding
34
+ from esm.utils.constants import esm3 as C
35
+ from esm.utils.constants.models import (
36
+ ESM3_OPEN_SMALL,
37
+ normalize_model_name,
38
+ )
39
+ from esm.utils.decoding import decode_protein_tensor
40
+ from esm.utils.generation import (
41
+ _batch_forward,
42
+ _sample_per_prompt,
43
+ _slice_tensor_dataclass,
44
+ iterative_sampling_raw,
45
+ iterative_sampling_tokens,
46
+ )
47
+ from esm.utils.misc import rbf
48
+ from esm.utils.sampling import (
49
+ _BatchedESMProteinTensor,
50
+ get_default_sampling_config,
51
+ validate_sampling_config,
52
+ )
53
+ from esm.utils.structure.affine3d import (
54
+ build_affine3d_from_coordinates,
55
+ )
56
+
57
+
58
+ @dataclass
59
+ class ESMOutput:
60
+ sequence_logits: torch.Tensor
61
+ structure_logits: torch.Tensor
62
+ secondary_structure_logits: torch.Tensor
63
+ sasa_logits: torch.Tensor
64
+ function_logits: torch.Tensor
65
+ residue_logits: torch.Tensor
66
+ embeddings: torch.Tensor
67
+
68
+
69
+ class EncodeInputs(nn.Module):
70
+ """
71
+ Module for encoding input features in the ESM-3 model.
72
+
73
+ Args:
74
+ d_model (int): The dimensionality of the model's hidden states.
75
+ """
76
+
77
+ def __init__(self, d_model: int):
78
+ super().__init__()
79
+
80
+ # Sequence
81
+ self.sequence_embed = nn.Embedding(64, d_model)
82
+ # Mandatory information
83
+ self.plddt_projection = nn.Linear(16, d_model)
84
+ self.structure_per_res_plddt_projection = nn.Linear(16, d_model)
85
+
86
+ # Structure
87
+ self.structure_tokens_embed = nn.Embedding(4096 + 5, d_model)
88
+
89
+ # "Structural" features
90
+ self.ss8_embed = nn.Embedding(8 + 3, d_model)
91
+ self.sasa_embed = nn.Embedding(16 + 3, d_model)
92
+
93
+ # "Functional" features
94
+ self.function_embed = nn.ModuleList(
95
+ [nn.Embedding(260, d_model // 8, padding_idx=0) for _ in range(8)]
96
+ )
97
+
98
+ self.residue_embed = nn.EmbeddingBag(1478, d_model, mode="sum", padding_idx=0)
99
+
100
+ def forward(
101
+ self,
102
+ sequence_tokens: torch.Tensor,
103
+ structure_tokens: torch.Tensor,
104
+ average_plddt: torch.Tensor,
105
+ per_res_plddt: torch.Tensor,
106
+ ss8_tokens: torch.Tensor,
107
+ sasa_tokens: torch.Tensor,
108
+ function_tokens: torch.Tensor,
109
+ residue_annotation_tokens: torch.Tensor,
110
+ ) -> torch.Tensor:
111
+ sequence_embed = self.sequence_embed(sequence_tokens)
112
+
113
+ rbf_16_fn = partial(rbf, v_min=0.0, v_max=1.0, n_bins=16)
114
+ # the `masked_fill(padding_mask.unsqueeze(2), 0)` for the two below is unnecessary
115
+ # as pad tokens never even interact with the "real" tokens (due to sequence_id)
116
+ plddt_embed = self.plddt_projection(rbf_16_fn(average_plddt))
117
+ structure_per_res_plddt = self.structure_per_res_plddt_projection(
118
+ rbf_16_fn(per_res_plddt)
119
+ )
120
+
121
+ # Structure + "structural features" embeds
122
+ structure_embed = self.structure_tokens_embed(structure_tokens)
123
+ ss8_embed = self.ss8_embed(ss8_tokens)
124
+ sasa_embed = self.sasa_embed(sasa_tokens)
125
+
126
+ # "Functional" features embeds
127
+ function_embed = torch.cat(
128
+ [
129
+ embed_fn(funcs)
130
+ for embed_fn, funcs in zip(
131
+ self.function_embed, function_tokens.unbind(-1)
132
+ )
133
+ ],
134
+ -1,
135
+ )
136
+
137
+ # Residue embeds
138
+ B, L, N = residue_annotation_tokens.shape
139
+ residue_embed = self.residue_embed(
140
+ einops.rearrange(
141
+ residue_annotation_tokens, "B L N -> (B L) N", B=B, L=L, N=N
142
+ )
143
+ )
144
+ residue_embed = einops.rearrange(residue_embed, "(B L) D -> B L D", B=B, L=L)
145
+
146
+ return (
147
+ sequence_embed
148
+ + plddt_embed
149
+ + structure_per_res_plddt
150
+ + structure_embed
151
+ + ss8_embed
152
+ + sasa_embed
153
+ + function_embed
154
+ + residue_embed
155
+ )
156
+
157
+
158
+ class OutputHeads(nn.Module):
159
+ def __init__(self, d_model: int):
160
+ super().__init__()
161
+ self.sequence_head = RegressionHead(d_model, 64)
162
+ self.structure_head = RegressionHead(d_model, 4096)
163
+ self.ss8_head = RegressionHead(d_model, 8 + 3)
164
+ self.sasa_head = RegressionHead(d_model, 16 + 3)
165
+ self.function_head = RegressionHead(d_model, 260 * 8)
166
+ self.residue_head = RegressionHead(d_model, 1478)
167
+
168
+ def forward(self, x: torch.Tensor, embed: torch.Tensor) -> ESMOutput:
169
+ sequence_logits = self.sequence_head(x)
170
+ structure_logits = self.structure_head(x)
171
+ secondary_structure_logits = self.ss8_head(x)
172
+ sasa_logits = self.sasa_head(x)
173
+ function_logits = self.function_head(x)
174
+ function_logits = einops.rearrange(function_logits, "... (k v) -> ... k v", k=8)
175
+
176
+ residue_logits = self.residue_head(x)
177
+
178
+ return ESMOutput(
179
+ sequence_logits=sequence_logits,
180
+ structure_logits=structure_logits,
181
+ secondary_structure_logits=secondary_structure_logits,
182
+ sasa_logits=sasa_logits,
183
+ function_logits=function_logits,
184
+ residue_logits=residue_logits,
185
+ embeddings=embed,
186
+ )
187
+
188
+
189
+ class ESM3(nn.Module, ESM3InferenceClient):
190
+ """
191
+ ESM3 model implementation.
192
+
193
+ Args:
194
+ d_model (int): The dimensionality of the input and output feature vectors.
195
+ n_heads (int): The number of attention heads in the transformer layers.
196
+ v_heads (int): The number of vector heads used by the geometric attention layers.
197
+ n_layers (int): The number of transformer layers.
198
+ """
199
+
200
+ def __init__(
201
+ self,
202
+ d_model: int,
203
+ n_heads: int,
204
+ v_heads: int,
205
+ n_layers: int,
206
+ structure_encoder_fn: Callable[[torch.device | str], StructureTokenEncoder],
207
+ structure_decoder_fn: Callable[[torch.device | str], StructureTokenDecoder],
208
+ function_decoder_fn: Callable[[torch.device | str], FunctionTokenDecoder],
209
+ tokenizers: TokenizerCollectionProtocol,
210
+ ):
211
+ super().__init__()
212
+ self.encoder = EncodeInputs(d_model)
213
+ self.transformer = TransformerStack(
214
+ d_model, n_heads, v_heads, n_layers, mask_and_zero_frameless=True
215
+ )
216
+ self.output_heads = OutputHeads(d_model)
217
+
218
+ self.structure_encoder_fn = structure_encoder_fn
219
+ self.structure_decoder_fn = structure_decoder_fn
220
+ self.function_decoder_fn = function_decoder_fn
221
+
222
+ self._structure_encoder = None
223
+ self._structure_decoder = None
224
+ self._function_decoder = None
225
+
226
+ self.tokenizers = tokenizers
227
+
228
+ @classmethod
229
+ def from_pretrained(
230
+ cls, model_name: str = ESM3_OPEN_SMALL, device: torch.device | None = None
231
+ ) -> ESM3:
232
+ from esm.pretrained import load_local_model
233
+
234
+ model_name = normalize_model_name(model_name)
235
+ if not model_name:
236
+ raise ValueError(f"Model name {model_name} is not a valid ESM3 model name.")
237
+ if device is None:
238
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
239
+ model = load_local_model(model_name, device=device)
240
+ if device.type != "cpu":
241
+ model = model.to(torch.bfloat16)
242
+ assert isinstance(model, ESM3)
243
+ return model
244
+
245
+ @property
246
+ def device(self):
247
+ return next(self.parameters()).device
248
+
249
+ @property
250
+ def raw_model(self):
251
+ return self
252
+
253
+ def get_structure_encoder(self) -> StructureTokenEncoder:
254
+ if self._structure_encoder is None:
255
+ self._structure_encoder = self.structure_encoder_fn(self.device)
256
+ return self._structure_encoder
257
+
258
+ def get_structure_decoder(self) -> StructureTokenDecoder:
259
+ if self._structure_decoder is None:
260
+ self._structure_decoder = self.structure_decoder_fn(self.device)
261
+ return self._structure_decoder
262
+
263
+ def get_function_decoder(self) -> FunctionTokenDecoder:
264
+ if self._function_decoder is None:
265
+ self._function_decoder = self.function_decoder_fn(self.device)
266
+ return self._function_decoder
267
+
268
+ def forward(
269
+ self,
270
+ *,
271
+ sequence_tokens: torch.Tensor | None = None,
272
+ structure_tokens: torch.Tensor | None = None,
273
+ ss8_tokens: torch.Tensor | None = None,
274
+ sasa_tokens: torch.Tensor | None = None,
275
+ function_tokens: torch.Tensor | None = None,
276
+ residue_annotation_tokens: torch.Tensor | None = None,
277
+ average_plddt: torch.Tensor | None = None,
278
+ per_res_plddt: torch.Tensor | None = None,
279
+ structure_coords: torch.Tensor | None = None,
280
+ chain_id: torch.Tensor | None = None,
281
+ sequence_id: torch.Tensor | None = None,
282
+ ) -> ESMOutput:
283
+ """
284
+ Performs forward pass through the ESM3 model. Check utils to see how to tokenize inputs from raw data.
285
+
286
+ Args:
287
+ sequence_tokens (torch.Tensor, optional): The amino acid tokens.
288
+ structure_tokens (torch.Tensor, optional): The structure tokens.
289
+ ss8_tokens (torch.Tensor, optional): The secondary structure tokens.
290
+ sasa_tokens (torch.Tensor, optional): The solvent accessible surface area tokens.
291
+ function_tokens (torch.Tensor, optional): The function tokens.
292
+ residue_annotation_tokens (torch.Tensor, optional): The residue annotation tokens.
293
+ average_plddt (torch.Tensor, optional): The average plddt across the entire sequence.
294
+ per_res_plddt (torch.Tensor, optional): The per residue plddt, if you want to specify exact plddts, use this,
295
+ otherwise, use average_plddt.
296
+ structure_coords (torch.Tensor, optional): The structure coordinates, in the form of (B, L, 3, 3).
297
+ chain_id (torch.Tensor, optional): The chain ID
298
+ sequence_id (torch.Tensor, optional): The sequence ID.
299
+
300
+ Returns:
301
+ ESMOutput: The output of the ESM3 model.
302
+
303
+ Raises:
304
+ ValueError: If at least one of the inputs is None.
305
+
306
+ """
307
+ # Reasonable defaults:
308
+ try:
309
+ L, device = next(
310
+ (x.shape[1], x.device)
311
+ for x in [
312
+ sequence_tokens,
313
+ structure_tokens,
314
+ ss8_tokens,
315
+ sasa_tokens,
316
+ structure_coords,
317
+ function_tokens,
318
+ residue_annotation_tokens,
319
+ ]
320
+ if x is not None
321
+ )
322
+ except StopIteration:
323
+ raise ValueError("At least one of the inputs must be non-None")
324
+
325
+ t = self.tokenizers
326
+ defaults = lambda x, tok: (
327
+ torch.full((1, L), tok, dtype=torch.long, device=device) if x is None else x
328
+ )
329
+ sequence_tokens = defaults(sequence_tokens, t.sequence.mask_token_id)
330
+ ss8_tokens = defaults(ss8_tokens, C.SS8_PAD_TOKEN)
331
+ sasa_tokens = defaults(sasa_tokens, C.SASA_PAD_TOKEN)
332
+ average_plddt = defaults(average_plddt, 1).float()
333
+ per_res_plddt = defaults(per_res_plddt, 0).float()
334
+ chain_id = defaults(chain_id, 0)
335
+
336
+ if residue_annotation_tokens is None:
337
+ residue_annotation_tokens = torch.full(
338
+ (1, L, 16), C.RESIDUE_PAD_TOKEN, dtype=torch.long, device=device
339
+ )
340
+
341
+ if function_tokens is None:
342
+ function_tokens = torch.full(
343
+ (1, L, 8), C.INTERPRO_PAD_TOKEN, dtype=torch.long, device=device
344
+ )
345
+
346
+ if structure_coords is None:
347
+ structure_coords = torch.full(
348
+ (1, L, 3, 3), float("nan"), dtype=torch.float, device=device
349
+ )
350
+
351
+ structure_coords = structure_coords[
352
+ ..., :3, :
353
+ ] # In case we pass in an atom14 or atom37 repr
354
+ affine, affine_mask = build_affine3d_from_coordinates(structure_coords)
355
+
356
+ structure_tokens = defaults(structure_tokens, C.STRUCTURE_MASK_TOKEN)
357
+ assert structure_tokens is not None
358
+ structure_tokens = (
359
+ structure_tokens.masked_fill(structure_tokens == -1, C.STRUCTURE_MASK_TOKEN)
360
+ .masked_fill(sequence_tokens == C.SEQUENCE_BOS_TOKEN, C.STRUCTURE_BOS_TOKEN)
361
+ .masked_fill(sequence_tokens == C.SEQUENCE_PAD_TOKEN, C.STRUCTURE_PAD_TOKEN)
362
+ .masked_fill(sequence_tokens == C.SEQUENCE_EOS_TOKEN, C.STRUCTURE_EOS_TOKEN)
363
+ .masked_fill(
364
+ sequence_tokens == C.SEQUENCE_CHAINBREAK_TOKEN,
365
+ C.STRUCTURE_CHAINBREAK_TOKEN,
366
+ )
367
+ )
368
+
369
+ x = self.encoder(
370
+ sequence_tokens,
371
+ structure_tokens,
372
+ average_plddt,
373
+ per_res_plddt,
374
+ ss8_tokens,
375
+ sasa_tokens,
376
+ function_tokens,
377
+ residue_annotation_tokens,
378
+ )
379
+ x, embedding, hidden = self.transformer(
380
+ x, sequence_id, affine, affine_mask, chain_id
381
+ )
382
+ return self.output_heads(x, embedding), hidden # MODIFIED FOR DYNA-1
383
+
384
+ # The following methods are for the ESM3InferenceClient interface
385
+ def generate(self, input: ProteinType, config: GenerationConfig) -> ProteinType:
386
+ """Wrap around batched generation."""
387
+ proteins = self.batch_generate([input], [config])
388
+ assert len(proteins) == 1
389
+ return proteins[0]
390
+
391
+ def batch_generate(
392
+ self, inputs: list[ProteinType], configs: list[GenerationConfig]
393
+ ) -> list[ProteinType]:
394
+ assert len(inputs) == len(
395
+ configs
396
+ ), "Must have the same number of prompts and configs."
397
+
398
+ if inputs == []:
399
+ # Nothing to do.
400
+ return []
401
+
402
+ # Make sure prompts are of the same type.
403
+ t = type(inputs[0])
404
+ for i in range(1, len(inputs)):
405
+ assert isinstance(inputs[i], t), (
406
+ "Prompts must have the same type. Got "
407
+ f"{t.__name__ and type(inputs[i]).__name__} instead."
408
+ )
409
+
410
+ if isinstance(inputs[0], ESMProtein):
411
+ return iterative_sampling_raw(self, inputs, configs) # type: ignore
412
+ elif isinstance(inputs[0], ESMProteinTensor):
413
+ return iterative_sampling_tokens(
414
+ self,
415
+ inputs, # type: ignore
416
+ configs,
417
+ self.tokenizers, # type: ignore
418
+ )
419
+ else:
420
+ raise ValueError("Input must be an ESMProtein or ESMProteinTensor")
421
+
422
+ def encode(self, input: ESMProtein) -> ESMProteinTensor:
423
+ input = attr.evolve(input) # Make a copy
424
+
425
+ sequence_tokens = None
426
+ structure_tokens = None
427
+ secondary_structure_tokens = None
428
+ sasa_tokens = None
429
+ function_tokens = None
430
+ residue_annotation_tokens = None
431
+
432
+ coordinates = None
433
+
434
+ if input.sequence is not None:
435
+ sequence_tokens = encoding.tokenize_sequence(
436
+ input.sequence, self.tokenizers.sequence, add_special_tokens=True
437
+ )
438
+ if input.secondary_structure is not None:
439
+ secondary_structure_tokens = encoding.tokenize_secondary_structure(
440
+ input.secondary_structure,
441
+ self.tokenizers.secondary_structure,
442
+ add_special_tokens=True,
443
+ )
444
+ if input.sasa is not None:
445
+ sasa_tokens = encoding.tokenize_sasa(
446
+ input.sasa, self.tokenizers.sasa, add_special_tokens=True
447
+ )
448
+
449
+ # Infer input length
450
+ sequence_length = -1
451
+ if sequence_tokens is not None:
452
+ sequence_length = len(sequence_tokens)
453
+ elif secondary_structure_tokens is not None:
454
+ sequence_length = len(secondary_structure_tokens)
455
+ elif sasa_tokens is not None:
456
+ sequence_length = len(sasa_tokens)
457
+
458
+ # Try to infer input length from structure data
459
+ if input.coordinates is not None:
460
+ coordinates, _, structure_tokens = encoding.tokenize_structure(
461
+ input.coordinates,
462
+ self.get_structure_encoder(),
463
+ structure_tokenizer=self.tokenizers.structure,
464
+ reference_sequence=input.sequence or "",
465
+ add_special_tokens=True,
466
+ )
467
+ if sequence_length == -1:
468
+ sequence_length = len(structure_tokens)
469
+
470
+ if sequence_length == -1:
471
+ raise ValueError(
472
+ "Cannot infer input length from input data. Please provide one of: sequence, structure, secondary_structure, sasa.\n"
473
+ "To condition on sequence length only, use ESM3LocalInferenceClient.get_default_sequence(sequence_length) to generate a default sequence input."
474
+ )
475
+
476
+ # Function and Residue annotations
477
+ if input.function_annotations is not None:
478
+ if input.sequence is None:
479
+ reference_sequence = encoding.get_default_sequence(sequence_length - 2)
480
+ else:
481
+ reference_sequence = input.sequence
482
+ (function_tokens, residue_annotation_tokens) = (
483
+ encoding.tokenize_function_annotations(
484
+ input.function_annotations,
485
+ reference_sequence=reference_sequence,
486
+ function_tokenizer=self.tokenizers.function,
487
+ residue_annotation_tokenizer=self.tokenizers.residue_annotations,
488
+ add_special_tokens=True,
489
+ )
490
+ )
491
+
492
+ return ESMProteinTensor(
493
+ sequence=sequence_tokens,
494
+ structure=structure_tokens,
495
+ secondary_structure=secondary_structure_tokens,
496
+ sasa=sasa_tokens,
497
+ function=function_tokens,
498
+ residue_annotations=residue_annotation_tokens,
499
+ coordinates=coordinates,
500
+ ).to(next(self.parameters()).device)
501
+
502
+ def decode(self, input: ESMProteinTensor) -> ESMProtein:
503
+ return decode_protein_tensor(
504
+ input=input,
505
+ tokenizers=self.tokenizers,
506
+ structure_token_decoder=self.get_structure_decoder(),
507
+ function_token_decoder=self.get_function_decoder(),
508
+ )
509
+
510
+ def logits(
511
+ self,
512
+ input: ESMProteinTensor | _BatchedESMProteinTensor,
513
+ config: LogitsConfig = LogitsConfig(),
514
+ ) -> LogitsOutput:
515
+ if not isinstance(input, _BatchedESMProteinTensor):
516
+ # Create batch dimension if necessary.
517
+ input = _BatchedESMProteinTensor.from_protein_tensor(input)
518
+
519
+ device = torch.device(input.device)
520
+
521
+ # Default plddt conditioning for inference. 1s where coordinates are provided.
522
+ if input.coordinates is None:
523
+ per_res_plddt = None
524
+ else:
525
+ # 1.0 if all coordinates at specific indices have valid non-nan values.
526
+ per_res_plddt = input.coordinates.isfinite().all(dim=-1).any(dim=-1).float()
527
+
528
+ with (
529
+ torch.no_grad(), # Assume no gradients for now...
530
+ torch.autocast(enabled=True, device_type=device.type, dtype=torch.bfloat16) # type: ignore
531
+ if device.type == "cuda"
532
+ else contextlib.nullcontext(),
533
+ ):
534
+ output, _ = self.forward(  # forward returns (output, hidden) after the Dyna-1 change
535
+ sequence_tokens=input.sequence,
536
+ structure_tokens=input.structure,
537
+ ss8_tokens=input.secondary_structure,
538
+ sasa_tokens=input.sasa,
539
+ function_tokens=input.function,
540
+ residue_annotation_tokens=input.residue_annotations,
541
+ average_plddt=torch.tensor(1.0, device=input.device),
542
+ per_res_plddt=per_res_plddt,
543
+ structure_coords=input.coordinates,
544
+ chain_id=None,
545
+ sequence_id=None,
546
+ )
547
+
548
+ output = ESMOutput(
549
+ **{k: v.to(device).to(torch.float32) for k, v in vars(output).items()}
550
+ )
551
+
552
+ return LogitsOutput(
553
+ logits=ForwardTrackData(
554
+ sequence=output.sequence_logits if config.sequence else None,
555
+ structure=output.structure_logits if config.structure else None,
556
+ secondary_structure=output.secondary_structure_logits
557
+ if config.secondary_structure
558
+ else None,
559
+ sasa=output.sasa_logits if config.sasa else None,
560
+ function=output.function_logits if config.function else None,
561
+ ),
562
+ residue_annotation_logits=output.residue_logits
563
+ if config.residue_annotations
564
+ else None,
565
+ embeddings=output.embeddings if config.return_embeddings else None,
566
+ )
567
+
568
+ def forward_and_sample(
569
+ self, input: ESMProteinTensor, sampling_configuration: SamplingConfig
570
+ ) -> ForwardAndSampleOutput:
571
+ validate_sampling_config(sampling_configuration, on_invalid="warn")
572
+
573
+ protein_tensor = attr.evolve(input) # Make a copy
574
+
575
+ device = next(self.parameters()).device
576
+
577
+ sampling_config = sampling_configuration
578
+ if sampling_config is None:
579
+ sampling_config = get_default_sampling_config(self.tokenizers)
580
+
581
+ # Initialize default values for missing tracks
582
+ default_protein_tensor = ESMProteinTensor.empty(
583
+ len(input) - 2, tokenizers=self.tokenizers, device=input.device
584
+ )
585
+ for track in attr.fields(ESMProteinTensor):
586
+ if getattr(protein_tensor, track.name, None) is None:
587
+ setattr(
588
+ protein_tensor,
589
+ track.name,
590
+ getattr(default_protein_tensor, track.name, None),
591
+ )
592
+
593
+ if len(protein_tensor) <= 0:
594
+ raise ValueError("No input data provided")
595
+
596
+ # Move input protein to proper device.
597
+ batched_protein = _BatchedESMProteinTensor.from_protein_tensor(protein_tensor)
598
+ batched_protein.to(device)
599
+
600
+ logits_output: LogitsOutput = _batch_forward(self, batched_protein)
601
+ forward_and_sample_out: ForwardAndSampleOutput = _sample_per_prompt(
602
+ batched_protein, logits_output, sampling_config, self.tokenizers
603
+ )
604
+
605
+ # There is only 1 prompt to sample for.
606
+ return _slice_tensor_dataclass(forward_and_sample_out, 0)
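A hedged end-to-end sketch of the Dyna-1-modified forward pass, which returns an (ESMOutput, hidden_states) pair rather than ESMOutput alone; the example sequence is arbitrary:

import torch

model = ESM3.from_pretrained()                       # ESM3_OPEN_SMALL by default
protein = model.encode(ESMProtein(sequence="MVLSEGEWQLVLHVWAKVE"))
with torch.no_grad():
    output, hidden = model.forward(sequence_tokens=protein.sequence.unsqueeze(0))
# hidden: (n_layers, batch, seq_len, d_model) per-layer embeddings used by Dyna-1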
Dyna-1/esm/models/esmc.py ADDED
@@ -0,0 +1,164 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+
5
+ import attr
6
+ import torch
7
+ import torch.nn as nn
8
+ from attr import dataclass
9
+
10
+ from esm.layers.regression_head import RegressionHead
11
+ from esm.layers.transformer_stack import TransformerStack
12
+ from esm.sdk.api import (
13
+ ESMCInferenceClient,
14
+ ESMProtein,
15
+ ESMProteinTensor,
16
+ ForwardTrackData,
17
+ LogitsConfig,
18
+ LogitsOutput,
19
+ )
20
+ from esm.tokenization import EsmSequenceTokenizer
21
+ from esm.utils import encoding
22
+ from esm.utils.constants.models import ESMC_600M
23
+ from esm.utils.decoding import decode_sequence
24
+ from esm.utils.misc import stack_variable_length_tensors
25
+ from esm.utils.sampling import _BatchedESMProteinTensor
26
+
27
+
28
+ @dataclass
29
+ class ESMCOutput:
30
+ sequence_logits: torch.Tensor
31
+ embeddings: torch.Tensor | None
32
+ hidden_states: torch.Tensor | None
33
+
34
+
35
+ class ESMC(nn.Module, ESMCInferenceClient):
36
+ """
37
+ ESMC model implementation.
38
+
39
+ Args:
40
+ d_model (int): The dimensionality of the input and output feature vectors.
41
+ n_heads (int): The number of attention heads in the transformer layers.
42
+ n_layers (int): The number of transformer layers.
43
+ """
44
+
45
+ def __init__(
46
+ self, d_model: int, n_heads: int, n_layers: int, tokenizer: EsmSequenceTokenizer
47
+ ):
48
+ super().__init__()
49
+ self.embed = nn.Embedding(64, d_model)
50
+ self.transformer = TransformerStack(
51
+ d_model, n_heads, None, n_layers, n_layers_geom=0
52
+ )
53
+ self.sequence_head = RegressionHead(d_model, 64)
54
+ self.tokenizer = tokenizer
55
+
56
+ @classmethod
57
+ def from_pretrained(
58
+ cls, model_name: str = ESMC_600M, device: torch.device | None = None
59
+ ) -> ESMC:
60
+ from esm.pretrained import load_local_model
61
+
62
+ if device is None:
63
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
64
+ model = load_local_model(model_name, device=device)
65
+ if device.type != "cpu":
66
+ model = model.to(torch.bfloat16)
67
+ assert isinstance(model, ESMC)
68
+ return model
69
+
70
+ @property
71
+ def device(self):
72
+ return next(self.parameters()).device
73
+
74
+ @property
75
+ def raw_model(self):
76
+ return self
77
+
78
+ def _tokenize(self, sequence: list[str]) -> torch.Tensor:
79
+ pad = self.tokenizer.pad_token_id
80
+ assert pad is not None
81
+ return stack_variable_length_tensors(
82
+ [
83
+ encoding.tokenize_sequence(x, self.tokenizer, add_special_tokens=True)
84
+ for x in sequence
85
+ ],
86
+ constant_value=pad,
87
+ ).to(next(self.parameters()).device)
88
+
89
+ def _detokenize(self, sequence: torch.Tensor) -> list[str]:
90
+ pad = self.tokenizer.pad_token_id
91
+ assert pad is not None
92
+ assert sequence.ndim == 2
93
+ return [decode_sequence(x[x != pad][1:-1], self.tokenizer) for x in sequence]
94
+
95
+ def forward(
96
+ self,
97
+ sequence_tokens: torch.Tensor | None = None,
98
+ sequence_id: torch.Tensor | None = None,
99
+ ) -> ESMCOutput:
100
+ """
101
+ Performs forward pass through the ESMC model. Check utils to see how to tokenize inputs from raw data.
102
+
103
+ Args:
104
+ sequence_tokens (torch.Tensor, optional): The amino acid tokens.
105
+ sequence_id (torch.Tensor, optional): The sequence ID.
106
+
107
+ Returns:
108
+ ESMCOutput: The output of the ESMC model.
109
+
110
+ """
111
+ if sequence_id is None:
112
+ sequence_id = sequence_tokens == self.tokenizer.pad_token_id
113
+
114
+ x = self.embed(sequence_tokens)
115
+ x, _, hidden = self.transformer(x, sequence_id=sequence_id)
116
+ sequence_logits = self.sequence_head(x)
117
+ output = ESMCOutput(
118
+ sequence_logits=sequence_logits, embeddings=x, hidden_states=hidden
119
+ )
120
+ return output, hidden  # MODIFIED FOR DYNA-1
121
+
122
+ def encode(self, input: ESMProtein) -> ESMProteinTensor:
123
+ input = attr.evolve(input) # Make a copy
124
+ sequence_tokens = None
125
+
126
+ if input.sequence is not None:
127
+ sequence_tokens = self._tokenize([input.sequence])[0]
128
+ return ESMProteinTensor(sequence=sequence_tokens).to(
129
+ next(self.parameters()).device
130
+ )
131
+
132
+ def decode(self, input: ESMProteinTensor) -> ESMProtein:
133
+ input = attr.evolve(input) # Make a copy
134
+
135
+ assert input.sequence is not None
136
+ sequence = self._detokenize(input.sequence)[0]
137
+
138
+ return ESMProtein(sequence=sequence)
139
+
140
+ def logits(
141
+ self,
142
+ input: ESMProteinTensor | _BatchedESMProteinTensor,
143
+ config: LogitsConfig = LogitsConfig(),
144
+ ) -> LogitsOutput:
145
+ if not isinstance(input, _BatchedESMProteinTensor):
146
+ # Create batch dimension if necessary.
147
+ input = _BatchedESMProteinTensor.from_protein_tensor(input)
148
+
149
+ device = torch.device(input.device)
150
+
151
+ with (
152
+ torch.no_grad(),
153
+ torch.autocast(enabled=True, device_type=device.type, dtype=torch.bfloat16) # type: ignore
154
+ if device.type == "cuda"
155
+ else contextlib.nullcontext(),
156
+ ):
157
+ output, _ = self.forward(sequence_tokens=input.sequence)  # forward returns (output, hidden) after the Dyna-1 change
158
+
159
+ return LogitsOutput(
160
+ logits=ForwardTrackData(
161
+ sequence=output.sequence_logits if config.sequence else None
162
+ ),
163
+ embeddings=output.embeddings if config.return_embeddings else None,
164
+ )
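A hedged sketch for the ESMC path, mirroring the modified return signature above; the sequence is arbitrary and _tokenize is the module's own helper:

import torch

model = ESMC.from_pretrained()                       # ESMC_600M by default
tokens = model._tokenize(["MVLSEGEWQLVLHVWAKVE"])
with torch.no_grad():
    output, hidden = model(sequence_tokens=tokens)
print(output.sequence_logits.shape)                  # (1, L + 2, 64), incl. BOS/EOS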
Dyna-1/esm/models/function_decoder.py ADDED
@@ -0,0 +1,306 @@
1
+ """Function Token Decoder."""
2
+
3
+ from collections import defaultdict
4
+ from dataclasses import dataclass, field
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from cloudpathlib import AnyPath
12
+
13
+ from esm.layers.regression_head import RegressionHead
14
+ from esm.layers.transformer_stack import TransformerStack
15
+ from esm.tokenization.function_tokenizer import (
16
+ InterProQuantizedTokenizer,
17
+ )
18
+ from esm.utils.constants import esm3 as C
19
+ from esm.utils.misc import merge_annotations, merge_ranges
20
+ from esm.utils.types import FunctionAnnotation
21
+
22
+
23
+ @dataclass(frozen=True)
24
+ class FunctionTokenDecoderConfig:
25
+ """Configures function token decoder."""
26
+
27
+ # Embedding dimension of decoder.
28
+ d_model: int = 1024
29
+ # Number of attention heads of decoder.
30
+ n_heads: int = 8
31
+ # Number of layers of decoder.
32
+ n_layers: int = 3
33
+ # Number of integer values that function tokens may assume.
34
+ function_token_vocab_size: int = 260
35
+ # Number of function tokens at each position.
36
+ function_token_depth: int = 8
37
+ # Number of InterPro labels that can be decoded.
38
+ num_interpro_classes: int = 29026
39
+ # Number of function keywords that can be decoded.
40
+ keyword_vocabulary_size: int = 58641
41
+ # List of supported InterPro ids.
42
+ interpro_entry_list: str = field(default_factory=lambda: str(C.INTERPRO_ENTRY))
43
+ # Path to keywords vocabulary.
44
+ keyword_vocabulary_path: str = field(
45
+ default_factory=lambda: str(C.data_root("esm3") / C.KEYWORDS_VOCABULARY)
46
+ )
47
+ # Whether to unpack LSH bits into single-bit tokens.
48
+ unpack_lsh_bits: bool = True
49
+ # The number of special tokens in the function tokenizer vocabulary which come
50
+ # before the LSH tokens.
51
+ num_special_tokens: int = 4
52
+ # The number of bits per LSH token in the function tokenizer.
53
+ bits_per_token: int = 8
54
+
55
+
56
+ class FunctionTokenDecoder(nn.Module):
57
+ def __init__(self, config: FunctionTokenDecoderConfig | None = None):
58
+ """Constructs function token decoder."""
59
+ super().__init__()
60
+ if config is None:
61
+ config = FunctionTokenDecoderConfig()
62
+ self.config = config
63
+
64
+ # Get the supported set of InterPro ids.
65
+ with AnyPath(config.interpro_entry_list).open("r") as f:
66
+ df = pd.read_csv(f, sep="\t")
67
+ self.interpro_ids = sorted(df.ENTRY_AC)
68
+ self.interpro2index = {
69
+ interpro_id: i for i, interpro_id in enumerate(self.interpro_ids)
70
+ }
71
+ assert len(self.interpro_ids) == config.num_interpro_classes
72
+
73
+ with AnyPath(config.keyword_vocabulary_path).open("r") as f:
74
+ self.keywords_vocabulary: list[str] = list(f.read().strip().split("\n"))
75
+ assert len(self.keywords_vocabulary) == config.keyword_vocabulary_size
76
+
77
+ if config.unpack_lsh_bits:
78
+ vocab_size = 2 * config.function_token_depth * config.bits_per_token
79
+ else:
80
+ # Function-token id's re-use the same token ids at each position along the depth
81
+ # dimension, despite distinct meanings. The decoder should take this into
82
+ # account so create distinct embeddings for tokens at each position.
83
+ vocab_size = (
84
+ self.config.function_token_depth * self.config.function_token_vocab_size
85
+ )
86
+
87
+ self.embedding = nn.Embedding(
88
+ # Function-token id's re-use the same token ids at each position along the
89
+ # depth dimension, despite distinct meanings. The decoder should take this
90
+ # into account so create distinct embeddings for tokens at each position.
91
+ num_embeddings=(vocab_size),
92
+ embedding_dim=config.d_model,
93
+ )
94
+ self.decoder = TransformerStack(
95
+ d_model=config.d_model,
96
+ n_heads=config.n_heads,
97
+ v_heads=None,
98
+ n_layers=config.n_layers,
99
+ n_layers_geom=0,
100
+ scale_residue=False,
101
+ bias=True,
102
+ qk_layernorm=False,
103
+ ffn_type="gelu",
104
+ expansion_ratio=4,
105
+ )
106
+ self.heads = nn.ModuleDict(
107
+ {
108
+ # Binary classification head predicting which keywords are present.
109
+ "keyword_logits": RegressionHead(
110
+ d_model=config.d_model,
111
+ output_dim=config.keyword_vocabulary_size,
112
+ hidden_dim=4 * config.d_model,
113
+ ),
114
+ # Regresses the TF-IDF value of each present keyword.
115
+ "keyword_tfidf": RegressionHead(
116
+ d_model=config.d_model,
117
+ output_dim=config.keyword_vocabulary_size,
118
+ hidden_dim=4 * config.d_model,
119
+ ),
120
+ # Predicts which InterPro annotations are present.
121
+ "interpro_logits": RegressionHead(
122
+ d_model=config.d_model,
123
+ output_dim=config.num_interpro_classes,
124
+ hidden_dim=4 * config.d_model,
125
+ ),
126
+ }
127
+ )
128
+
129
+ def forward(self, token_ids: torch.Tensor) -> dict[str, torch.Tensor]:
130
+ """Forward pass through function token decoder.
131
+
132
+ Args:
133
+ token_ids: <int>[batch_size, function_token_depth] batch of function tokens
134
+ ids to decode.
135
+ Returns:
136
+ interpro_logits: binary classification logits tensor of shape
137
+ <float>[batch_size, num_interpro_classes]
138
+ """
139
+ assert token_ids.ndim == 2
140
+ assert token_ids.shape[1] == self.config.function_token_depth
141
+ batch_size, depth = token_ids.shape
142
+
143
+ if self.config.unpack_lsh_bits:
144
+ # Shift values into [0, 2^bits/token)
145
+ lsh_bits = token_ids - self.config.num_special_tokens
146
+ # extract each bit. (hob stands for highest-order bit)
147
+ bits = torch.concat(
148
+ [
149
+ torch.bitwise_and(lsh_bits, 1 << hob).gt(0).to(torch.int32)
150
+ for hob in range(self.config.bits_per_token)
151
+ ],
152
+ dim=1,
153
+ )
154
+ assert bits.shape == (batch_size, depth * self.config.bits_per_token)
155
+
156
+ # Shift each bit into individual vocabulary ranges, so they get distinct
157
+ # embeddings.
158
+ vocab_offsets = 2 * torch.arange(
159
+ depth * self.config.bits_per_token, device=token_ids.device
160
+ )
161
+ inputs = vocab_offsets[None, :] + bits
162
+
163
+ # zero-out special tokens, i.e. non LSH tokens.
164
+ where_special = token_ids < self.config.num_special_tokens
165
+ inputs = torch.where(where_special.any(dim=1, keepdim=True), 0, inputs)
166
+ else:
167
+ # Apply depth-position offset to use distinct vocabs. See __init__ for
168
+ # explanation.
169
+ vocab_offsets = self.config.function_token_vocab_size * torch.arange(
170
+ self.config.function_token_depth, device=token_ids.device
171
+ )
172
+ inputs = token_ids + vocab_offsets[None, :]
173
+
174
+ embed = self.embedding(inputs)
175
+ encoding, _, _ = self.decoder(embed)
176
+ pooled = torch.mean(encoding, dim=1)
177
+
178
+ return {name: head(pooled) for name, head in self.heads.items()}
179
+
180
+ @property
181
+ def device(self) -> torch.device:
182
+ return next(self.parameters()).device
183
+
184
+ def decode(
185
+ self,
186
+ function_token_ids: torch.Tensor,
187
+ tokenizer: InterProQuantizedTokenizer,
188
+ decode_annotations: bool = True,
189
+ annotation_threshold: float = 0.1,
190
+ decode_keywords=True,
191
+ keywords_threshold: float = 0.5,
192
+ annotation_min_length: int | None = 5,
193
+ annotation_gap_merge_max: int | None = 3,
194
+ ):
195
+ """Decodes function tokens into predicted annotations and keywords.
196
+
197
+ Args:
198
+ function_token_ids: <int>[length, depth] function token ids. NOTE:
199
+ without <bos>/<eos> prefix
200
+ tokenizer: function tokenizer.
201
+ decode_annotations: whether to decode InterPro annotations.
202
+ annotation_threshold: threshold for emitting a function annotation.
203
+ decode_keywords: whether to decode function keywords.
204
+ keywords_threshold: threshold for emitting a keyword.
205
+ annotation_min_length: optional minimum length of predicted annotations for
206
+ size filtering.
207
+ annotation_gap_merge_max: optionally merge adjacent annotations of the same type within this gap.
208
+ Returns:
209
+ Decoder outputs:
210
+ - "interpro_logits": <float>[length, num_interpro] predicted interpro logits.
211
+ - "interpro_preds": <bool>[length, num_interpro] predicted intepro labels.
212
+ - "interpro_annotations": list[FunctionAnnotation] predicted InterPro
213
+ annotations
214
+ - "keyword_logits": <float>[length, keyword_vocabulary] binary prediciton
215
+ logits for keywrods.
216
+ - "function_keywords": list[FunctionAnnotation] predicted function keyword
217
+ ranges.
218
+ """
219
+ assert function_token_ids.ndim == 2
220
+ assert function_token_ids.shape[1] == tokenizer.depth
221
+ assert self.config.function_token_depth == tokenizer.depth
222
+
223
+ outputs = self(function_token_ids.to(self.device))
226
+
227
+ # Only decode in positions that have function tokens.
228
+ where_decode = torch.all(
229
+ (function_token_ids != tokenizer.vocab_to_index["<pad>"])
230
+ & (function_token_ids != tokenizer.vocab_to_index["<none>"])
231
+ & (function_token_ids != tokenizer.vocab_to_index["<unk>"]),
232
+ dim=1,
233
+ )
234
+
235
+ # Decode InterPro annotations ranges.
236
+ interpro_preds = F.sigmoid(outputs["interpro_logits"])
237
+ interpro_preds = interpro_preds >= annotation_threshold
238
+ interpro_preds[~where_decode, :] = False
239
+ outputs["interpro_preds"] = interpro_preds
240
+ if decode_annotations:
241
+ annotations: list[FunctionAnnotation] = []
242
+ preds: np.ndarray = interpro_preds.detach().cpu().numpy()
243
+ for position_index, class_index in zip(*preds.nonzero()):
244
+ interpro_id = self.interpro_ids[class_index]
245
+ annotation = FunctionAnnotation(
246
+ label=interpro_id,
247
+ start=position_index, # one-index inclusive (BOS shifts indexes +1)
248
+ end=position_index, # one-index inclusive
249
+ )
250
+ annotations.append(annotation)
251
+
252
+ annotations = merge_annotations(
253
+ annotations, merge_gap_max=annotation_gap_merge_max
254
+ )
255
+
256
+ # Drop very small annotations.
257
+ if annotation_min_length is not None:
258
+ annotations = [
259
+ annotation
260
+ for annotation in annotations
261
+ if annotation.end - annotation.start + 1 >= annotation_min_length
262
+ ]
263
+
264
+ outputs["interpro_annotations"] = annotations
265
+
266
+ # Decode function keyword ranges.
267
+ keyword_logits = outputs["keyword_logits"]
268
+ keyword_logits[~where_decode, :] = -torch.inf
269
+ if decode_keywords:
270
+ keyword_preds = F.sigmoid(keyword_logits) >= keywords_threshold
271
+ outputs["function_keywords"] = self._preds_to_keywords(
272
+ keyword_preds.detach().cpu().numpy()
273
+ )
274
+
275
+ return outputs
276
+
277
+ def _preds_to_keywords(self, keyword_preds: np.ndarray) -> list[FunctionAnnotation]:
278
+ """Converts output log-TFDF to predicted keywords over the sequence.
279
+
280
+ Args:
281
+ keyword_preds: <bool>[length, keyword_vocab] positional predictions of
282
+ function keywords from the keyword prediction head.
283
+ Returns:
284
+ Non-overlapping keyword annotated ranges along the sequence. Note that indices
285
+ will index into the *sequence*, not the function token array which has a
286
+ <pad> prefix.
287
+ """
288
+ assert keyword_preds.ndim == 2
289
+ assert keyword_preds.shape[1] == self.config.keyword_vocabulary_size
290
+
291
+ keyword_positions: dict[str, list[range]] = defaultdict(list)
292
+ for position, keyword_id in zip(*np.nonzero(keyword_preds)):
293
+ keyword = self.keywords_vocabulary[keyword_id]
294
+ keyword_positions[keyword].append(range(position, position + 1))
295
+
296
+ annotations: list[FunctionAnnotation] = []
297
+ for keyword, ranges in keyword_positions.items():
298
+ for range_ in merge_ranges(ranges):
299
+ annotation = FunctionAnnotation(
300
+ label=keyword,
301
+ start=range_.start, # one-index inclusive (BOS shifts indexes +1)
302
+ end=range_.stop - 1, # one-index exclusive -> one-index inclusive
303
+ )
304
+ annotations.append(annotation)
305
+
306
+ return annotations
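
A minimal sketch of the range-merging idea behind `_preds_to_keywords` above, using a simplified stand-in for `merge_ranges` that joins overlapping or touching per-position hits (an illustration, not the library implementation):

def merge_adjacent(ranges: list[range]) -> list[range]:
    # Sort by start, then greedily extend the last merged range.
    merged: list[range] = []
    for r in sorted(ranges, key=lambda r: r.start):
        if merged and r.start <= merged[-1].stop:  # overlapping or touching
            merged[-1] = range(merged[-1].start, max(merged[-1].stop, r.stop))
        else:
            merged.append(r)
    return merged

# Positions 3, 4, 5 and 9 predicted for one keyword -> two annotation ranges.
print(merge_adjacent([range(3, 4), range(4, 5), range(5, 6), range(9, 10)]))
# [range(3, 6), range(9, 10)]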
Dyna-1/esm/models/vqvae.py ADDED
@@ -0,0 +1,440 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from esm.layers.blocks import UnifiedTransformerBlock
5
+ from esm.layers.codebook import EMACodebook
6
+ from esm.layers.structure_proj import Dim6RotStructureHead
7
+ from esm.layers.transformer_stack import TransformerStack
8
+ from esm.utils.constants import esm3 as C
9
+ from esm.utils.misc import knn_graph
10
+ from esm.utils.structure.affine3d import (
11
+ Affine3D,
12
+ build_affine3d_from_coordinates,
13
+ )
14
+ from esm.utils.structure.predicted_aligned_error import (
15
+ compute_predicted_aligned_error,
16
+ compute_tm,
17
+ )
18
+
19
+
20
+ class RelativePositionEmbedding(nn.Module):
21
+ """
22
+ Embedding layer for relative position embeddings. `bins` is the number of positions relative
23
+ to the query position that are considered before clipping. For instance, if `bins=10`, then
24
+ the relative position embedding will have 21 positions, [-10, 10].
25
+ """
26
+
27
+ def __init__(self, bins, embedding_dim, init_std=0.02):
28
+ super().__init__()
29
+ self.bins = bins
30
+
31
+ self.embedding = torch.nn.Embedding(2 * bins + 2, embedding_dim)
32
+ self.embedding.weight.data.normal_(0, init_std)
33
+
34
+ def forward(self, query_residue_index, key_residue_index):
35
+ """
36
+ Input:
37
+ query_residue_index: (B, ) tensor of source indices (dtype=torch.long)
38
+ key_residue_index: (B, L) tensor of target indices (dtype=torch.long)
39
+ Output:
40
+ embeddings: B x L x embedding_dim tensor of embeddings
41
+ """
42
+
43
+ assert query_residue_index.dtype == torch.long
44
+ assert key_residue_index.dtype == torch.long
45
+ assert query_residue_index.ndim == 1
46
+ assert key_residue_index.ndim == 2
47
+
48
+ diff = key_residue_index - query_residue_index.unsqueeze(1)
49
+ diff = diff.clamp(-self.bins, self.bins)
50
+ diff = diff + self.bins + 1 # add 1 to adjust for padding index
51
+ output = self.embedding(diff)
52
+ return output
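
# Illustrative sketch (not part of this file): the clamp-and-offset indexing
# above maps relative offsets into an embedding table of size 2*bins + 2.
# With bins = 10:
#   query = torch.tensor([5])            # (B,)
#   keys = torch.tensor([[0, 5, 30]])    # (B, L)
#   diff = keys - query.unsqueeze(1)     # tensor([[-5,  0, 25]])
#   diff = diff.clamp(-bins, bins)       # tensor([[-5,  0, 10]])
#   idx = diff + bins + 1                # tensor([[ 6, 11, 21]])
# so every index lands in [1, 2*bins + 1], leaving 0 free as a padding slot.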
53
+
54
+
55
+ class PairwisePredictionHead(nn.Module):
56
+ def __init__(
57
+ self,
58
+ input_dim: int,
59
+ downproject_dim: int,
60
+ hidden_dim: int,
61
+ n_bins: int,
62
+ bias: bool = True,
63
+ pairwise_state_dim: int = 0,
64
+ ):
65
+ super().__init__()
66
+ self.downproject = nn.Linear(input_dim, downproject_dim, bias=bias)
67
+ self.linear1 = nn.Linear(
68
+ downproject_dim + pairwise_state_dim, hidden_dim, bias=bias
69
+ )
70
+ self.activation_fn = nn.GELU()
71
+ self.norm = nn.LayerNorm(hidden_dim)
72
+ self.linear2 = nn.Linear(hidden_dim, n_bins, bias=bias)
73
+
74
+ def forward(self, x, pairwise: torch.Tensor | None = None):
75
+ """
76
+ Args:
77
+ x: [B x L x D]
78
+
79
+ Output:
80
+ [B x L x L x K]
81
+ """
82
+ x = self.downproject(x)
83
+ # Let x_i be a vector of size (B, D).
84
+ # Input is {x_1, ..., x_L} of size (B, L, D)
85
+ # Output is 2D where x_ij = cat([x_i * x_j, x_i - x_j])
86
+ q, k = x.chunk(2, dim=-1)
87
+
88
+ prod = q[:, None, :, :] * k[:, :, None, :]
89
+ diff = q[:, None, :, :] - k[:, :, None, :]
90
+ x_2d = [prod, diff]
91
+ if pairwise is not None:
92
+ x_2d.append(pairwise)
93
+ x = torch.cat(x_2d, dim=-1)
94
+ x = self.linear1(x)
95
+ x = self.activation_fn(x)
96
+ x = self.norm(x)
97
+ x = self.linear2(x)
98
+ return x
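
# Shape sketch of the pairwise featurization above (illustrative toy sizes,
# with B, L, D = 2, 7, 8 after the downprojection):
#   q, k = x.chunk(2, dim=-1)                    # each (B, L, D//2)
#   prod = q[:, None, :, :] * k[:, :, None, :]   # broadcasts to (B, L, L, D//2)
#   diff = q[:, None, :, :] - k[:, :, None, :]   # (B, L, L, D//2)
#   torch.cat([prod, diff], dim=-1)              # (B, L, L, D), then MLP -> (B, L, L, n_bins)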
99
+
100
+
101
+ class RegressionHead(nn.Module):
102
+ def __init__(self, embed_dim: int, output_dim: int):
103
+ super().__init__()
104
+ self.dense = nn.Linear(embed_dim, embed_dim)
105
+ self.activation_fn = nn.GELU()
106
+ self.norm = nn.LayerNorm(embed_dim)
107
+ self.output = nn.Linear(embed_dim, output_dim)
108
+
109
+ def forward(self, features):
110
+ x = self.dense(features)
111
+ x = self.activation_fn(x)
112
+ x = self.norm(x)
113
+ x = self.output(x)
114
+ return x
115
+
116
+
117
+ class CategoricalMixture:
118
+ def __init__(self, param, bins=50, start=0, end=1):
119
+ # All tensors are of shape ..., bins.
120
+ self.logits = param
121
+ bins = torch.linspace(
122
+ start, end, bins + 1, device=self.logits.device, dtype=torch.float32
123
+ )
124
+ self.v_bins = (bins[:-1] + bins[1:]) / 2
125
+
126
+ def log_prob(self, true):
127
+ # Shapes are:
128
+ # self.probs: ... x bins
129
+ # true : ... (floating point # for target)
130
+ true_index = (
131
+ (true.unsqueeze(-1) - self.v_bins[[None] * true.ndim]).abs().argmin(-1)
132
+ )
133
+ nll = self.logits.log_softmax(-1)
134
+ return torch.take_along_dim(nll, true_index.unsqueeze(-1), dim=-1).squeeze(-1)
135
+
136
+ def mean(self):
137
+ return (
138
+ self.logits.to(self.v_bins.dtype).softmax(-1) @ self.v_bins.unsqueeze(1)
139
+ ).squeeze(-1)
140
+
141
+ def median(self):
142
+ return self.v_bins[self.logits.max(-1).indices]
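
# Worked example (illustrative): the expected value under the binned
# distribution is a softmax-weighted average of bin centers. With 3 bins
# over [0, 1] and logits [0, 0, 10]:
#   edges = torch.linspace(0, 1, 4)          # 4 edges -> 3 bins
#   centers = (edges[:-1] + edges[1:]) / 2   # tensor([1/6, 1/2, 5/6])
#   torch.tensor([0., 0., 10.]).softmax(-1) @ centers  # ~0.83, dominated by the last bin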
143
+
144
+
145
+ class GeometricEncoderStack(TransformerStack):
146
+ def __init__(self, d_model, n_heads, v_heads, n_layers):
147
+ super().__init__(d_model, n_heads, v_heads, 0)
148
+ self.blocks = nn.ModuleList(
149
+ [
150
+ UnifiedTransformerBlock(
151
+ d_model,
152
+ n_heads,
153
+ v_heads=v_heads,
154
+ use_geom_attn=True,
155
+ use_plain_attn=False,
156
+ expansion_ratio=4,
157
+ bias=True,
158
+ )
159
+ for i in range(n_layers)
160
+ ]
161
+ )
162
+ self.norm = nn.Identity()
163
+
164
+
165
+ def batched_gather(data, inds, dim=0, no_batch_dims=0):
166
+ ranges = []
167
+ for i, s in enumerate(data.shape[:no_batch_dims]):
168
+ r = torch.arange(s)
169
+ r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
170
+ ranges.append(r)
171
+
172
+ remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)]
173
+ remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
174
+ ranges.extend(remaining_dims)
175
+ return data[ranges]
176
+
177
+
178
+ def node_gather(s: torch.Tensor, edges: torch.Tensor) -> torch.Tensor:
179
+ return batched_gather(s.unsqueeze(-3), edges, -2, no_batch_dims=len(s.shape) - 1)
180
+
181
+
182
+ class StructureTokenEncoder(nn.Module):
183
+ def __init__(self, d_model, n_heads, v_heads, n_layers, d_out, n_codes):
184
+ super().__init__()
185
+ # We only support fully-geometric structure token encoders for now...
186
+ # setting n_layers_geom to something that's not n_layers won't work because
187
+ # sequence ID isn't supported fully in this repo for plain-old transformers
188
+ self.transformer = GeometricEncoderStack(d_model, n_heads, v_heads, n_layers)
189
+ self.pre_vq_proj = nn.Linear(d_model, d_out)
190
+ self.codebook = EMACodebook(n_codes, d_out)
191
+ self.relative_positional_embedding = RelativePositionEmbedding(
192
+ 32, d_model, init_std=0.02
193
+ )
194
+ self.knn = 16
195
+
196
+ def encode_local_structure(
197
+ self,
198
+ coords: torch.Tensor,
199
+ affine: Affine3D,
200
+ attention_mask: torch.Tensor,
201
+ sequence_id: torch.Tensor | None,
202
+ affine_mask: torch.Tensor,
203
+ residue_index: torch.Tensor | None = None,
204
+ ):
205
+ """This function allows for a multi-layered encoder to encode tokens with a local receptive fields. The implementation is as follows:
206
+
207
+ 1. Starting with (B, L) frames, we find the KNN in structure space. This now gives us (B, L, K) where the last dimension is the local
208
+ neighborhood of all (B, L) residues.
209
+ 2. We reshape these frames to (B*L, K) so now we have a large batch of a bunch of local neighborhoods.
210
+ 3. Pass the (B*L, K) local neighborhoods through a stack of geometric reasoning blocks, effectively getting all-to-all communication between
211
+ all frames in the local neighborhood.
212
+ 4. This gives (B*L, K, d_model) embeddings, from which we need to get a single embedding per local neighborhood. We do this by simply
213
+ taking the embedding corresponding to the query node. This gives us (B*L, d_model) embeddings.
214
+ 5. Reshape back to (B, L, d_model) embeddings
215
+ """
216
+ assert coords.size(-1) == 3 and coords.size(-2) == 3, "need N, CA, C"
217
+ with torch.no_grad():
218
+ knn_edges, _ = self.find_knn_edges(
219
+ coords,
220
+ ~attention_mask,
221
+ coord_mask=affine_mask,
222
+ sequence_id=sequence_id,
223
+ knn=self.knn,
224
+ )
225
+ B, L, E = knn_edges.shape
226
+
227
+ affine_tensor = affine.tensor # for easier manipulation
228
+ T_D = affine_tensor.size(-1)
229
+ knn_affine_tensor = node_gather(affine_tensor, knn_edges)
230
+ knn_affine_tensor = knn_affine_tensor.view(-1, E, T_D).contiguous()
231
+ affine = Affine3D.from_tensor(knn_affine_tensor)
232
+ knn_sequence_id = (
233
+ node_gather(sequence_id.unsqueeze(-1), knn_edges).view(-1, E)
234
+ if sequence_id is not None
235
+ else torch.zeros(B * L, E, dtype=torch.int64, device=coords.device)
236
+ )
237
+ knn_affine_mask = node_gather(affine_mask.unsqueeze(-1), knn_edges).view(
238
+ -1, E
239
+ )
240
+ knn_chain_id = torch.zeros(
241
+ B * L, E, dtype=torch.int64, device=coords.device
242
+ )
243
+
244
+ if residue_index is None:
245
+ res_idxs = knn_edges.view(-1, E)
246
+ else:
247
+ res_idxs = node_gather(residue_index.unsqueeze(-1), knn_edges).view(
248
+ -1, E
249
+ )
250
+
251
+ z = self.relative_positional_embedding(res_idxs[:, 0], res_idxs)
252
+
253
+ z, _, _ = self.transformer.forward(
254
+ x=z,
255
+ sequence_id=knn_sequence_id,
256
+ affine=affine,
257
+ affine_mask=knn_affine_mask,
258
+ chain_id=knn_chain_id,
259
+ )
260
+
261
+ # Unflatten the output and take the query node embedding, which will always be the first one because
262
+ # a node has distance 0 with itself and the KNN are sorted.
263
+ z = z.view(B, L, E, -1)
264
+ z = z[:, :, 0, :]
265
+
266
+ return z
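
# Shape walkthrough of steps 1-5 above (illustrative, toy sizes):
#   B, L, K, d_model = 2, 50, 16, 32
#   knn_edges: (B, L, K) neighbor indices            # step 1
#   -> view(B * L, K): one row per neighborhood      # step 2
#   -> geometric blocks: (B * L, K, d_model)         # step 3
#   -> view(B, L, K, d_model)[:, :, 0, :]            # steps 4-5: keep the query node
#   -> (B, L, d_model)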
267
+
268
+ @staticmethod
269
+ def find_knn_edges(
270
+ coords,
271
+ padding_mask,
272
+ coord_mask,
273
+ sequence_id: torch.Tensor | None = None,
274
+ knn: int | None = None,
275
+ ) -> tuple:
276
+ assert knn is not None, "Must specify a non-null knn to find_knn_edges"
277
+ # Coords are N, CA, C
278
+ coords = coords.clone()
279
+ coords[~coord_mask] = 0
280
+
281
+ if sequence_id is None:
282
+ sequence_id = torch.zeros(
283
+ (coords.shape[0], coords.shape[1]), device=coords.device
284
+ ).long()
285
+
286
+ with torch.no_grad(), torch.cuda.amp.autocast(enabled=False): # type: ignore
287
+ ca = coords[..., 1, :]
288
+ edges, edge_mask = knn_graph(
289
+ ca, coord_mask, padding_mask, sequence_id, no_knn=knn
290
+ )
291
+
292
+ return edges, edge_mask
293
+
294
+ def encode(
295
+ self,
296
+ coords: torch.Tensor,
297
+ attention_mask: torch.Tensor | None = None,
298
+ sequence_id: torch.Tensor | None = None,
299
+ residue_index: torch.Tensor | None = None,
300
+ ):
301
+ coords = coords[..., :3, :]
302
+ affine, affine_mask = build_affine3d_from_coordinates(coords=coords)
303
+
304
+ if attention_mask is None:
305
+ attention_mask = torch.ones_like(affine_mask, dtype=torch.bool)
306
+ attention_mask = attention_mask.bool()
307
+
308
+ if sequence_id is None:
309
+ sequence_id = torch.zeros_like(affine_mask, dtype=torch.int64)
310
+
311
+ z = self.encode_local_structure(
312
+ coords=coords,
313
+ affine=affine,
314
+ attention_mask=attention_mask,
315
+ sequence_id=sequence_id,
316
+ affine_mask=affine_mask,
317
+ residue_index=residue_index,
318
+ )
319
+
320
+ z = z.masked_fill(~affine_mask.unsqueeze(2), 0)
321
+ z = self.pre_vq_proj(z)
322
+
323
+ z_q, min_encoding_indices, _ = self.codebook(z)
324
+
325
+ return z_q, min_encoding_indices
326
+
327
+
328
+ class StructureTokenDecoder(nn.Module):
329
+ def __init__(self, d_model, n_heads, n_layers):
330
+ super().__init__()
331
+ self.decoder_channels = d_model
332
+
333
+ self.vqvae_codebook_size = C.VQVAE_CODEBOOK_SIZE
334
+ self.special_tokens = C.VQVAE_SPECIAL_TOKENS
335
+ self.max_pae_bin = C.VQVAE_MAX_PAE_BIN
336
+
337
+ self.embed = nn.Embedding(
338
+ self.vqvae_codebook_size + len(self.special_tokens), d_model
339
+ )
340
+ self.decoder_stack = TransformerStack(
341
+ d_model, n_heads, 1, n_layers, scale_residue=False, n_layers_geom=0
342
+ )
343
+
344
+ self.affine_output_projection = Dim6RotStructureHead(
345
+ self.decoder_channels, 10, predict_torsion_angles=False
346
+ )
347
+
348
+ direction_loss_bins = C.VQVAE_DIRECTION_LOSS_BINS
349
+ pae_bins = C.VQVAE_PAE_BINS
350
+ self.pairwise_bins = [
351
+ 64, # distogram
352
+ direction_loss_bins * 6, # direction bins
353
+ pae_bins, # predicted aligned error
354
+ ]
355
+ self.pairwise_classification_head = PairwisePredictionHead(
356
+ self.decoder_channels,
357
+ downproject_dim=128,
358
+ hidden_dim=128,
359
+ n_bins=sum(self.pairwise_bins),
360
+ bias=False,
361
+ )
362
+
363
+ plddt_bins = C.VQVAE_PLDDT_BINS
364
+ self.plddt_head = RegressionHead(
365
+ embed_dim=self.decoder_channels, output_dim=plddt_bins
366
+ )
367
+
368
+ def decode(
369
+ self,
370
+ structure_tokens: torch.Tensor,
371
+ attention_mask: torch.Tensor | None = None,
372
+ sequence_id: torch.Tensor | None = None,
373
+ ):
374
+ if attention_mask is None:
375
+ attention_mask = torch.ones_like(structure_tokens, dtype=torch.bool)
376
+
377
+ attention_mask = attention_mask.bool()
378
+ if sequence_id is None:
379
+ sequence_id = torch.zeros_like(structure_tokens, dtype=torch.int64)
380
+ # not supported for now
381
+ chain_id = torch.zeros_like(structure_tokens, dtype=torch.int64)
382
+
383
+ # check that BOS and EOS are set correctly
384
+ assert (
385
+ structure_tokens[:, 0].eq(self.special_tokens["BOS"]).all()
386
+ ), "First token in structure_tokens must be BOS token"
387
+ assert (
388
+ structure_tokens[
389
+ torch.arange(structure_tokens.shape[0]), attention_mask.sum(1) - 1
390
+ ]
391
+ .eq(self.special_tokens["EOS"])
392
+ .all()
393
+ ), "Last token in structure_tokens must be EOS token"
394
+ assert (
395
+ (structure_tokens < 0).sum() == 0
396
+ ), "All structure tokens set to -1 should be replaced with BOS, EOS, PAD, or MASK tokens by now, but that isn't the case!"
397
+
398
+ x = self.embed(structure_tokens)
399
+ # !!! NOTE: Attention mask is actually unused here so watch out
400
+ x, _, _ = self.decoder_stack.forward(
401
+ x, affine=None, affine_mask=None, sequence_id=sequence_id, chain_id=chain_id
402
+ )
403
+
404
+ tensor7_affine, bb_pred = self.affine_output_projection(
405
+ x, affine=None, affine_mask=torch.zeros_like(attention_mask)
406
+ )
407
+
408
+ pae, ptm = None, None
409
+ pairwise_logits = self.pairwise_classification_head(x)
410
+ _, _, pae_logits = [
411
+ (o if o.numel() > 0 else None)
412
+ for o in pairwise_logits.split(self.pairwise_bins, dim=-1)
413
+ ]
414
+
415
+ special_tokens_mask = structure_tokens >= min(self.special_tokens.values())
416
+ pae = compute_predicted_aligned_error(
417
+ pae_logits, # type: ignore
418
+ aa_mask=~special_tokens_mask,
419
+ sequence_id=sequence_id,
420
+ max_bin=self.max_pae_bin,
421
+ )
422
+ # This might be broken for chainbreak tokens? We might align to the chainbreak
423
+ ptm = compute_tm(
424
+ pae_logits, # type: ignore
425
+ aa_mask=~special_tokens_mask,
426
+ max_bin=self.max_pae_bin,
427
+ )
428
+
429
+ plddt_logits = self.plddt_head(x)
430
+ plddt_value = CategoricalMixture(
431
+ plddt_logits, bins=plddt_logits.shape[-1]
432
+ ).mean()
433
+
434
+ return dict(
435
+ tensor7_affine=tensor7_affine,
436
+ bb_pred=bb_pred,
437
+ plddt=plddt_value,
438
+ ptm=ptm,
439
+ predicted_aligned_error=pae,
440
+ )
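
The decoder's single pairwise head is split back into its three bin groups, as in this standalone sketch (toy bin counts, not the configured constants):

import torch

pairwise_bins = [64, 16 * 6, 64]                   # distogram, direction, PAE bins (toy values)
logits = torch.randn(1, 5, 5, sum(pairwise_bins))
dist_logits, dir_logits, pae_logits = logits.split(pairwise_bins, dim=-1)
assert pae_logits.shape[-1] == pairwise_bins[-1]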
Dyna-1/esm/pretrained.py ADDED
@@ -0,0 +1,132 @@
1
+ from typing import Callable
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from esm.models.esm3 import ESM3
7
+ from esm.models.esmc import ESMC
8
+ from esm.models.function_decoder import FunctionTokenDecoder
9
+ from esm.models.vqvae import (
10
+ StructureTokenDecoder,
11
+ StructureTokenEncoder,
12
+ )
13
+ from esm.tokenization import (
14
+ get_esm3_model_tokenizers,
15
+ get_esmc_model_tokenizers,
16
+ )
17
+ from esm.utils.constants.esm3 import data_root
18
+ from esm.utils.constants.models import (
19
+ ESM3_FUNCTION_DECODER_V0,
20
+ ESM3_OPEN_SMALL,
21
+ ESM3_STRUCTURE_DECODER_V0,
22
+ ESM3_STRUCTURE_ENCODER_V0,
23
+ ESMC_300M,
24
+ ESMC_600M,
25
+ )
26
+
27
+ ModelBuilder = Callable[[torch.device | str], nn.Module]
28
+
29
+
30
+ def ESM3_structure_encoder_v0(device: torch.device | str = "cpu"):
31
+ with torch.device(device):
32
+ model = StructureTokenEncoder(
33
+ d_model=1024, n_heads=1, v_heads=128, n_layers=2, d_out=128, n_codes=4096
34
+ ).eval()
35
+ state_dict = torch.load(
36
+ data_root("esm3") / "data/weights/esm3_structure_encoder_v0.pth",
37
+ map_location=device,
38
+ )
39
+ model.load_state_dict(state_dict)
40
+ return model
41
+
42
+
43
+ def ESM3_structure_decoder_v0(device: torch.device | str = "cpu"):
44
+ with torch.device(device):
45
+ model = StructureTokenDecoder(d_model=1280, n_heads=20, n_layers=30).eval()
46
+ state_dict = torch.load(
47
+ data_root("esm3") / "data/weights/esm3_structure_decoder_v0.pth",
48
+ map_location=device,
49
+ )
50
+ model.load_state_dict(state_dict)
51
+ return model
52
+
53
+
54
+ def ESM3_function_decoder_v0(device: torch.device | str = "cpu"):
55
+ with torch.device(device):
56
+ model = FunctionTokenDecoder().eval()
57
+ state_dict = torch.load(
58
+ data_root("esm3") / "data/weights/esm3_function_decoder_v0.pth",
59
+ map_location=device,
60
+ )
61
+ model.load_state_dict(state_dict)
62
+ return model
63
+
64
+
65
+ def ESMC_300M_202412(device: torch.device | str = "cpu"):
66
+ with torch.device(device):
67
+ model = ESMC(
68
+ d_model=960, n_heads=15, n_layers=30, tokenizer=get_esmc_model_tokenizers()
69
+ ).eval()
70
+ state_dict = torch.load(
71
+ data_root("esmc-300") / "data/weights/esmc_300m_2024_12_v0.pth",
72
+ map_location=device,
73
+ )
74
+ model.load_state_dict(state_dict)
75
+
76
+ return model
77
+
78
+
79
+ def ESMC_600M_202412(device: torch.device | str = "cpu"):
80
+ with torch.device(device):
81
+ model = ESMC(
82
+ d_model=1152, n_heads=18, n_layers=36, tokenizer=get_esmc_model_tokenizers()
83
+ ).eval()
84
+ state_dict = torch.load(
85
+ data_root("esmc-600") / "data/weights/esmc_600m_2024_12_v0.pth",
86
+ map_location=device,
87
+ )
88
+ model.load_state_dict(state_dict)
89
+
90
+ return model
91
+
92
+
93
+ def ESM3_sm_open_v0(device: torch.device | str = "cpu"):
94
+ with torch.device(device):
95
+ model = ESM3(
96
+ d_model=1536,
97
+ n_heads=24,
98
+ v_heads=256,
99
+ n_layers=48,
100
+ structure_encoder_fn=ESM3_structure_encoder_v0,
101
+ structure_decoder_fn=ESM3_structure_decoder_v0,
102
+ function_decoder_fn=ESM3_function_decoder_v0,
103
+ tokenizers=get_esm3_model_tokenizers(ESM3_OPEN_SMALL),
104
+ ).eval()
105
+ state_dict = torch.load(
106
+ data_root("esm3") / "data/weights/esm3_sm_open_v1.pth", map_location=device
107
+ )
108
+ model.load_state_dict(state_dict)
109
+ return model
110
+
111
+
112
+ LOCAL_MODEL_REGISTRY: dict[str, ModelBuilder] = {
113
+ ESM3_OPEN_SMALL: ESM3_sm_open_v0,
114
+ ESM3_STRUCTURE_ENCODER_V0: ESM3_structure_encoder_v0,
115
+ ESM3_STRUCTURE_DECODER_V0: ESM3_structure_decoder_v0,
116
+ ESM3_FUNCTION_DECODER_V0: ESM3_function_decoder_v0,
117
+ ESMC_600M: ESMC_600M_202412,
118
+ ESMC_300M: ESMC_300M_202412,
119
+ }
120
+
121
+
122
+ def load_local_model(
123
+ model_name: str, device: torch.device = torch.device("cpu")
124
+ ) -> nn.Module:
125
+ if model_name not in LOCAL_MODEL_REGISTRY:
126
+ raise ValueError(f"Model {model_name} not found in local model registry.")
127
+ return LOCAL_MODEL_REGISTRY[model_name](device)
128
+
129
+
130
+ # Register custom versions of ESM3 for use with the local inference API
131
+ def register_local_model(model_name: str, model_builder: ModelBuilder) -> None:
132
+ LOCAL_MODEL_REGISTRY[model_name] = model_builder
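
A usage sketch for the registry above ("my-decoder" is a hypothetical key; the builder simply reuses a loader defined in this file):

import torch

def my_decoder_builder(device: torch.device | str = "cpu"):
    # Any callable taking a device and returning an nn.Module works as a builder.
    return ESM3_structure_decoder_v0(device)

register_local_model("my-decoder", my_decoder_builder)
model = load_local_model("my-decoder", device=torch.device("cpu"))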
Dyna-1/esm/sdk/__init__.py ADDED
@@ -0,0 +1,22 @@
1
+ import os
2
+
3
+ from esm.sdk.forge import ESM3ForgeInferenceClient
4
+
5
+ # Note: please do not import ESM3SageMakerClient here since that requires AWS SDK.
6
+
7
+
8
+ def client(
9
+ model="esm3-sm-open-v1",
10
+ url="https://forge.evolutionaryscale.ai",
11
+ token=os.environ.get("ESM_API_KEY", ""),
12
+ request_timeout=None,
13
+ ):
14
+ """
15
+ Args:
16
+ model: Name of the model to use.
17
+ url: URL of a forge server.
18
+ token: User's API token.
19
+ request_timeout: Amount of time to wait for a request to finish.
20
+ Default is to wait indefinitely.
21
+ """
22
+ return ESM3ForgeInferenceClient(model, url, token, request_timeout)
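
A usage sketch (the token value is a placeholder; ESM_API_KEY can be set in the environment instead):

from esm.sdk import client

forge_client = client(model="esm3-sm-open-v1", token="YOUR_API_TOKEN_HERE")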
Dyna-1/esm/sdk/api.py ADDED
@@ -0,0 +1,445 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC
4
+ from typing import List, Sequence
5
+
6
+ import attr
7
+ import torch
8
+ from attr import asdict, define
9
+
10
+ import esm.utils.constants.api as C
11
+ from esm.tokenization import (
12
+ TokenizerCollectionProtocol,
13
+ get_esm3_model_tokenizers,
14
+ )
15
+ from esm.utils import encoding
16
+ from esm.utils.constants.models import ESM3_OPEN_SMALL
17
+ from esm.utils.misc import (
18
+ get_chainbreak_boundaries_from_sequence,
19
+ )
20
+ from esm.utils.structure.protein_chain import ProteinChain
21
+ from esm.utils.structure.protein_complex import ProteinComplex
22
+ from esm.utils.types import FunctionAnnotation, PathOrBuffer
23
+
24
+
25
+ class ProteinType(ABC): ...
26
+
27
+
28
+ ## Basic Types
29
+ @define
30
+ class ESMProtein(ProteinType):
31
+ # Tracks
32
+ sequence: str | None = None
33
+ secondary_structure: str | None = None
34
+ sasa: list[float | None] | None = None
35
+ function_annotations: list[FunctionAnnotation] | None = None
36
+ coordinates: torch.Tensor | None = None
37
+
38
+ # Metrics
39
+ plddt: torch.Tensor | None = None
40
+ ptm: torch.Tensor | None = None
41
+
42
+
43
+ # When calling the EvolutionaryScale API, use this flag to disclose any
44
+ # sequences that may potentially have concerns.
45
+ # Such sequences may not go through the standard safety filter for approved users.
46
+ # Reach out if interested in using this.
47
+ potential_sequence_of_concern: bool = False
48
+
49
+ def __len__(self):
50
+ if self.sequence is not None:
51
+ return len(self.sequence)
52
+ elif self.secondary_structure is not None:
53
+ return len(self.secondary_structure)
54
+ elif self.sasa is not None:
55
+ return len(self.sasa)
56
+ elif self.coordinates is not None:
57
+ return self.coordinates.size(0)
58
+ else:
59
+ raise ValueError("No track to determine length from.")
60
+
61
+ @classmethod
62
+ def from_pdb(
63
+ cls,
64
+ path: PathOrBuffer,
65
+ chain_id: str = "detect",
66
+ id: str | None = None,
67
+ is_predicted: bool = False,
68
+ ) -> ESMProtein:
69
+ protein_chain = ProteinChain.from_pdb(
70
+ path=path, chain_id=chain_id, id=id, is_predicted=is_predicted
71
+ )
72
+ return cls.from_protein_chain(protein_chain)
73
+
74
+ @classmethod
75
+ def from_protein_chain(
76
+ cls, protein_chain: ProteinChain, with_annotations: bool = False
77
+ ) -> ESMProtein:
78
+ # By default, we don't annotate with DSSP / SASA, which are expensive.
79
+ # If mkdssp is installed, we can annotate with a flag.
80
+ if with_annotations:
81
+ return ESMProtein(
82
+ sequence=protein_chain.sequence,
83
+ secondary_structure=protein_chain.dssp().tolist(),
84
+ sasa=protein_chain.sasa().tolist(),
85
+ function_annotations=None,
86
+ coordinates=torch.tensor(protein_chain.atom37_positions),
87
+ )
88
+ else:
89
+ return ESMProtein(
90
+ sequence=protein_chain.sequence,
91
+ secondary_structure=None,
92
+ sasa=None,
93
+ function_annotations=None,
94
+ coordinates=torch.tensor(protein_chain.atom37_positions),
95
+ )
96
+
97
+ @classmethod
98
+ def from_protein_complex(
99
+ cls, protein_complex: ProteinComplex, with_annotations: bool = False
100
+ ) -> ESMProtein:
101
+ if with_annotations:
102
+ raise NotImplementedError(
103
+ "Annotations are not supported for ProteinComplex yet."
104
+ )
105
+
106
+ return ESMProtein(
107
+ sequence=protein_complex.sequence,
108
+ secondary_structure=None,
109
+ sasa=None,
110
+ function_annotations=None,
111
+ coordinates=torch.tensor(protein_complex.atom37_positions),
112
+ )
113
+
114
+ def to_pdb(self, pdb_path: PathOrBuffer) -> None:
115
+ # Note: Will work for single chains as well and produce the same pdb file
116
+ protein_complex = self.to_protein_complex().infer_oxygen()
117
+ protein_complex.to_pdb(pdb_path)
118
+
119
+ def to_pdb_string(self) -> str:
120
+ protein_chain = self.to_protein_chain()
121
+ return protein_chain.to_pdb_string()
122
+
123
+ def to_protein_chain(self) -> ProteinChain:
124
+ if self.coordinates is None:
125
+ raise ValueError("Coordinates are required to convert to a ProteinChain.")
126
+ protein_chain = ProteinChain.from_atom37(
127
+ atom37_positions=self.coordinates.to("cpu").numpy(),
128
+ id=None,
129
+ sequence=None if self.sequence is None else self.sequence.replace("_", "X"),
130
+ chain_id=None,
131
+ entity_id=None,
132
+ residue_index=None,
133
+ insertion_code=None,
134
+ confidence=None
135
+ if self.plddt is None
136
+ else self.plddt.detach().cpu().numpy(),
137
+ )
138
+ return protein_chain
139
+
140
+ def to_protein_complex(
141
+ self, copy_annotations_from_ground_truth: ProteinComplex | None = None
142
+ ) -> ProteinComplex:
143
+ assert (
144
+ self.sequence is not None
145
+ ), "ESMProtein must have a sequence to convert to ProteinComplex"
146
+ assert (
147
+ self.coordinates is not None
148
+ ), "ESMProtein must have coordinates to convert to ProteinComplex"
149
+ coords = self.coordinates.to("cpu").numpy()
150
+
151
+ chain_boundaries = get_chainbreak_boundaries_from_sequence(self.sequence)
152
+ if copy_annotations_from_ground_truth is not None:
153
+ gt_chains = list(copy_annotations_from_ground_truth.chain_iter())
154
+ else:
155
+ gt_chains = None
156
+ pred_chains = []
157
+ for i, (start, end) in enumerate(chain_boundaries):
158
+ pred_chain = ProteinChain.from_atom37(
159
+ atom37_positions=coords[start:end],
160
+ sequence=self.sequence[start:end],
161
+ chain_id=gt_chains[i].chain_id if gt_chains is not None else None,
162
+ entity_id=gt_chains[i].entity_id if gt_chains is not None else None,
163
+ )
164
+ pred_chains.append(pred_chain)
165
+ return ProteinComplex.from_chains(pred_chains)
166
+
167
+
168
+ @define
169
+ class ESMProteinTensor(ProteinType):
170
+ sequence: torch.Tensor | None = None
171
+ structure: torch.Tensor | None = None
172
+ secondary_structure: torch.Tensor | None = None
173
+ sasa: torch.Tensor | None = None
174
+ function: torch.Tensor | None = None
175
+ residue_annotations: torch.Tensor | None = None
176
+ coordinates: torch.Tensor | None = None
177
+
178
+ # When calling the EvolutionaryScale API, use this flag to disclose any
179
+ # sequences that may potentially have concerns.
180
+ # Such sequences may not go through the standard safety filter for approved users.
181
+ # Reach out if interested in using this.
182
+ potential_sequence_of_concern: bool = False
183
+ # Control vectors are vectors added to each layer of the model to nudge hidden states in the desired direction.
184
+ # len(control_vectors) == number of blocks in the model. Each vector in the list has the shape (batch size, sequence length, hidden dim)
185
+ # so it can be added to the corresponding layer in the model
186
+
187
+ def _detect_attribute(self, func, msg):
188
+ mapped = {
189
+ k: func(k, v)
190
+ for k, v in asdict(self).items()
191
+ if isinstance(v, torch.Tensor)
192
+ }
193
+ s = set(mapped.values())
194
+ if len(s) <= 0:
195
+ return None
196
+ if len(s) != 1:
197
+ raise ValueError(f"Either no tracks or inconsistent {msg}: {mapped}")
198
+ return next(iter(s))
199
+
200
+ def __len__(self) -> int:
201
+ l = self._detect_attribute(lambda _, x: x.size(0), "length")
202
+ return l if l is not None else 0
203
+
204
+ @property
205
+ def device(self) -> str | torch.device:
206
+ d = self._detect_attribute(lambda _, x: x.device, "device")
207
+ assert d is not None
208
+ return d
209
+
210
+ def to(self, device_or_dtype: str | torch.device | torch.dtype) -> ESMProteinTensor:
211
+ def _to(name):
212
+ v = getattr(self, name)
213
+ if v is not None and isinstance(v, torch.Tensor):
214
+ setattr(self, name, v.to(device_or_dtype))
215
+
216
+ for n in attr.fields(ESMProteinTensor):
217
+ _to(n.name)
218
+
219
+ return self
220
+
221
+ @classmethod
222
+ def empty(
223
+ cls,
224
+ length: int,
225
+ tokenizers: TokenizerCollectionProtocol | None = None,
226
+ device: torch.device | str = "cpu",
227
+ ) -> ESMProteinTensor:
228
+ if tokenizers is None:
229
+ tokenizers = get_esm3_model_tokenizers(ESM3_OPEN_SMALL)
230
+
231
+ return ESMProteinTensor(
232
+ sequence=encoding.get_default_sequence_tokens(
233
+ length, tokenizers.sequence
234
+ ).to(device),
235
+ structure=encoding.get_default_structure_tokens(
236
+ length, tokenizers.structure
237
+ ).to(device),
238
+ secondary_structure=encoding.get_default_secondary_structure_tokens(
239
+ length, tokenizers.secondary_structure
240
+ ).to(device),
241
+ sasa=encoding.get_default_sasa_tokens(length, tokenizers.sasa).to(device),
242
+ function=encoding.get_default_function_tokens(
243
+ length, tokenizers.function
244
+ ).to(device),
245
+ residue_annotations=encoding.get_default_residue_annotation_tokens(
246
+ length, tokenizers.residue_annotations
247
+ ).to(device),
248
+ )
249
+
250
+
251
+ @define
252
+ class ESMProteinError(Exception, ProteinType):
253
+ error_code: int  # Error code follows HTTP convention, e.g., 404 NotFoundError, 500 InternalError.
254
+ error_msg: str
255
+
256
+
257
+ ## High Level Endpoint Types
258
+ @define
259
+ class GenerationConfig:
260
+ track: str = ""
261
+ invalid_ids: Sequence[int] = []
262
+ # Controls the number of tokens to unmask during each round of iterative generation.
263
+ schedule: str = attr.field(
264
+ validator=attr.validators.in_(["cosine", "linear"]), default="cosine"
265
+ )
266
+ # Controls which tokens to unmask during each round of iterative generation.
267
+ # "random" will unmask a correct number of tokens randomly.
268
+ # "entropy" will unmask the tokens with the lowest logit entropy first.
269
+ strategy: str = attr.field(
270
+ validator=attr.validators.in_(["random", "entropy"]), default="entropy"
271
+ )
272
+ # Set this to a higher value for better generation results.
273
+ # Note that this needs to be less than or equal to the sequence length.
274
+ num_steps: int = 1
275
+ temperature: float = 1.0
276
+ temperature_annealing: bool = False
277
+ top_p: float = 1.0
278
+ condition_on_coordinates_only: bool = True
279
+
280
+ def use_entropy_based_unmasking_strategy(self):
281
+ """Use entropy based unmasking strategy during generation."""
282
+ self.schedule = "cosine"
283
+ self.strategy = "entropy"
284
+ self.temperature_annealing = False
285
+
286
+ def use_generative_unmasking_strategy(self):
287
+ """Use an unmasking strategy that produces more variety of generations."""
288
+ self.schedule = "cosine"
289
+ self.strategy = "random"
290
+ self.temperature_annealing = True
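
# For example (illustrative values, not part of this file):
#   config = GenerationConfig(track="sequence", num_steps=8, temperature=0.7)
#   config.use_generative_unmasking_strategy()  # cosine schedule, random unmasking, annealing on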
291
+
292
+
293
+ @define
294
+ class InverseFoldingConfig:
295
+ invalid_ids: Sequence[int] = []
296
+ temperature: float = 1.0
297
+
298
+
299
+ ## Low Level Endpoint Types
300
+ @define
301
+ class SamplingTrackConfig:
302
+ temperature: float = 1.0
303
+ top_p: float = 1.0
304
+ only_sample_masked_tokens: bool = True
305
+ invalid_ids: Sequence[int] = []
306
+ topk_logprobs: int = 0
307
+
308
+
309
+ @define
310
+ class SamplingConfig:
311
+ sequence: SamplingTrackConfig | None = attr.field(
312
+ default=None, metadata={"max_topk": C.MAX_TOPK_SEQUENCE}
313
+ )
314
+ structure: SamplingTrackConfig | None = attr.field(
315
+ default=None, metadata={"max_topk": C.MAX_TOPK_STRUCTURE}
316
+ )
317
+ secondary_structure: SamplingTrackConfig | None = attr.field(
318
+ default=None, metadata={"max_topk": C.MAX_TOPK_SECONDARY_STRUCTURE}
319
+ )
320
+ sasa: SamplingTrackConfig | None = attr.field(
321
+ default=None, metadata={"max_topk": C.MAX_TOPK_SASA}
322
+ )
323
+ function: SamplingTrackConfig | None = attr.field(
324
+ default=None, metadata={"max_topk": C.MAX_TOPK_FUNCTION}
325
+ )
326
+
327
+ return_per_residue_embeddings: bool = False
328
+ return_mean_embedding: bool = False
329
+
330
+
331
+ @define
332
+ class ForwardTrackData:
333
+ sequence: torch.Tensor | None = None
334
+ structure: torch.Tensor | None = None
335
+ secondary_structure: torch.Tensor | None = None
336
+ sasa: torch.Tensor | None = None
337
+ function: torch.Tensor | None = None
338
+
339
+
340
+ @define
341
+ class LogitsConfig:
342
+ # Logits.
343
+ sequence: bool = False
344
+ structure: bool = False
345
+ secondary_structure: bool = False
346
+ sasa: bool = False
347
+ function: bool = False
348
+ residue_annotations: bool = False
349
+
350
+ # Embeddings.
351
+ return_embeddings: bool = False
352
+
353
+
354
+ @define
355
+ class LogitsOutput:
356
+ logits: ForwardTrackData | None = None
357
+ embeddings: torch.Tensor | None = None
358
+
359
+ # Residue annotations are multi-hot, so they deserve special treatment
360
+ # It's not a categorical distribution, but instead a Bernoulli per class, so
361
+ # softmax across the last dimension is _wrong_
362
+ residue_annotation_logits: torch.Tensor | None = None
363
+
364
+
365
+ @define
366
+ class ForwardAndSampleOutput(LogitsOutput):
367
+ protein_tensor: ESMProteinTensor = ESMProteinTensor()
368
+
369
+ entropy: ForwardTrackData | None = None
370
+ # Probability of sampled token
371
+ prob: ForwardTrackData | None = None
372
+ logprob: ForwardTrackData | None = None
373
+ # Top probability at this position
374
+ top_prob: ForwardTrackData | None = None
375
+ topk_logprob: ForwardTrackData | None = None
376
+ # Which tokens correspond to top probability
377
+ topk_tokens: ForwardTrackData | None = None
378
+ per_residue_embedding: torch.Tensor | None = None
379
+ mean_embedding: torch.Tensor | None = None
380
+
381
+
382
+ class ESM3InferenceClient(ABC):
383
+ def generate(self, input: ProteinType, config: GenerationConfig) -> ProteinType:
384
+ # This is the easiest and most flexible way to run ESM3. Generate will
385
+ # iteratively sample tokens and provide an output with the track specified
386
+ # completely filled out, according to the GenerationConfig provided.
387
+ # It is a local function wrapping calls for encode -> iterative_sampling -> decode.
388
+ # If an ESMProteinTensor is provided, encode and decode are skipped.
389
+ raise NotImplementedError
390
+
391
+ def batch_generate(
392
+ self, inputs: Sequence[ProteinType], configs: Sequence[GenerationConfig]
393
+ ) -> Sequence[ProteinType]:
394
+ # Same as generate(...), but generates a batch of proteins at once.
395
+ raise NotImplementedError
396
+
397
+ def encode(self, input: ESMProtein) -> ESMProteinTensor:
398
+ # Encode allows for encoding RawRepresentation into TokenizedRepresentation.
399
+ # This runs the structure_token_encoder, as well as dealing with PDB => atom37 conversion
400
+ raise NotImplementedError
401
+
402
+ def decode(self, input: ESMProteinTensor) -> ESMProtein:
403
+ # Decode is the inverse of encode, and runs a structure_token_decoder to output coordinates
404
+ raise NotImplementedError
405
+
406
+ def logits(
407
+ self, input: ESMProteinTensor, config: LogitsConfig = LogitsConfig()
408
+ ) -> LogitsOutput:
409
+ # Our API generally discourages using raw forwards.
410
+ # This is because sending logits can be prohibitively expensive.
411
+ # Please use forward_and_sample instead.
412
+ raise NotImplementedError
413
+
414
+ def forward_and_sample(
415
+ self, input: ESMProteinTensor, sampling_configuration: SamplingConfig
416
+ ) -> ForwardAndSampleOutput:
417
+ # forward_and_sample runs a single model forward, sampling tokens according to `SamplingConfiguration`.
418
+ # This is the way for power users to run ESM3. We hope to design this in a way to enable high throughput
419
+ # inference, as well as arbitrary chain-of-thought invocations of ESM3.
420
+ raise NotImplementedError
421
+
422
+ @property
423
+ def raw_model(self):
424
+ # Get underlying esm3 model of an inference client.
425
+ raise NotImplementedError
426
+
427
+
428
+ class ESMCInferenceClient(ABC):
429
+ def encode(self, input: ESMProtein) -> ESMProteinTensor:
430
+ # Encode allows for encoding RawRepresentation into TokenizedRepresentation.
431
+ raise NotImplementedError
432
+
433
+ def decode(self, input: ESMProteinTensor) -> ESMProtein:
434
+ # Decode is the inverse of encode
435
+ raise NotImplementedError
436
+
437
+ def logits(
438
+ self, input: ESMProteinTensor, config: LogitsConfig = LogitsConfig()
439
+ ) -> LogitsOutput:
440
+ raise NotImplementedError
441
+
442
+ @property
443
+ def raw_model(self):
444
+ # Get underlying esmc model of an inference client.
445
+ raise NotImplementedError
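
Putting the abstract API together, a round trip through any concrete ESM3InferenceClient might look like the sketch below ("some_client" is a placeholder for a real implementation, and "_" is assumed to mark masked sequence positions):

protein = ESMProtein(sequence="MPKL" + "_" * 12)
result = some_client.generate(protein, GenerationConfig(track="sequence", num_steps=8))
if isinstance(result, ESMProteinError):
    raise RuntimeError(f"{result.error_code}: {result.error_msg}")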
Dyna-1/esm/sdk/forge.py ADDED
@@ -0,0 +1,580 @@
1
+ import asyncio
2
+ from functools import wraps
3
+ from typing import Sequence
4
+ from urllib.parse import urljoin
5
+
6
+ import requests
7
+ import torch
8
+ from tenacity import retry, retry_if_result, stop_after_attempt, wait_exponential
9
+
10
+ from esm.sdk.api import (
11
+ ESM3InferenceClient,
12
+ ESMProtein,
13
+ ESMProteinError,
14
+ ESMProteinTensor,
15
+ ForwardAndSampleOutput,
16
+ ForwardTrackData,
17
+ GenerationConfig,
18
+ InverseFoldingConfig,
19
+ LogitsConfig,
20
+ LogitsOutput,
21
+ ProteinType,
22
+ SamplingConfig,
23
+ SamplingTrackConfig,
24
+ )
25
+ from esm.utils.misc import maybe_list, maybe_tensor
26
+ from esm.utils.sampling import validate_sampling_config
27
+ from esm.utils.types import FunctionAnnotation
28
+
29
+
30
+ def _list_to_function_annotations(l) -> list[FunctionAnnotation] | None:
31
+ if l is None or len(l) <= 0:
32
+ return None
33
+ return [FunctionAnnotation(*t) for t in l]
34
+
35
+
36
+ def retry_if_specific_error(exception):
37
+ """
38
+ We only retry on specific errors.
39
+ Currently we retry on 429 (rate limit), 502 (bad gateway), and 504 (gateway timeout).
40
+ """
41
+ return isinstance(exception, ESMProteinError) and exception.error_code in {
42
+ 429,
43
+ 502,
44
+ 504,
45
+ }
46
+
47
+
48
+ def log_retry_attempt(retry_state):
49
+ print(
50
+ f"Retrying... Attempt {retry_state.attempt_number} after {retry_state.next_action.sleep}s due to: {retry_state.outcome.result()}"
51
+ )
52
+
53
+
54
+ def _validate_protein_tensor_input(input):
55
+ if not isinstance(input, ESMProteinTensor):
56
+ raise ValueError(
57
+ "Input must be an ESMProteinTensor instance. "
58
+ "Use encode() API to encode an ESMProtein into ESMProteinTensor."
59
+ )
60
+
61
+
62
+ class SequenceStructureForgeInferenceClient:
63
+ def __init__(
64
+ self,
65
+ url: str = "https://forge.evolutionaryscale.ai",
66
+ token: str = "",
67
+ request_timeout: int | None = None,
68
+ ):
69
+ if token == "":
70
+ raise RuntimeError(
71
+ "Please provide a token to connect to Forge via token=YOUR_API_TOKEN_HERE"
72
+ )
73
+ self.url = url
74
+ self.token = token
75
+ self.headers = {"Authorization": f"Bearer {self.token}"}
76
+ self.request_timeout = request_timeout
77
+
78
+ def fold(
79
+ self,
80
+ sequence: str,
81
+ potential_sequence_of_concern: bool,
82
+ model_name: str | None = None,
83
+ ) -> ESMProtein | ESMProteinError:
84
+ request = {"sequence": sequence}
85
+ if model_name is not None:
86
+ request["model"] = model_name
87
+ try:
88
+ data = self._post("fold", request, potential_sequence_of_concern)
89
+ except ESMProteinError as e:
90
+ return e
91
+
92
+ return ESMProtein(
93
+ coordinates=maybe_tensor(data["coordinates"], convert_none_to_nan=True)
94
+ )
95
+
96
+ def inverse_fold(
97
+ self,
98
+ coordinates: torch.Tensor,
99
+ config: InverseFoldingConfig,
100
+ potential_sequence_of_concern: bool,
101
+ model_name: str | None = None,
102
+ ) -> ESMProtein | ESMProteinError:
103
+ inverse_folding_config = {
104
+ "invalid_ids": config.invalid_ids,
105
+ "temperature": config.temperature,
106
+ }
107
+ request = {
108
+ "coordinates": maybe_list(coordinates, convert_nan_to_none=True),
109
+ "inverse_folding_config": inverse_folding_config,
110
+ }
111
+ if model_name is not None:
112
+ request["model"] = model_name
113
+ try:
114
+ data = self._post("inverse_fold", request, potential_sequence_of_concern)
115
+ except ESMProteinError as e:
116
+ return e
117
+
118
+ return ESMProtein(sequence=data["sequence"])
119
+
120
+ def _post(self, endpoint, request, potential_sequence_of_concern):
121
+ request["potential_sequence_of_concern"] = potential_sequence_of_concern
122
+
123
+ response = requests.post(
124
+ urljoin(self.url, f"/api/v1/{endpoint}"),
125
+ json=request,
126
+ headers=self.headers,
127
+ timeout=self.request_timeout,
128
+ )
129
+
130
+ if not response.ok:
131
+ raise ESMProteinError(
132
+ error_code=response.status_code,
133
+ error_msg=f"Failure in {endpoint}: {response.text}",
134
+ )
135
+
136
+ data = response.json()
137
+ # Nextjs puts outputs dict under "data" key.
138
+ # Lift it up for easier downstream processing.
139
+ if "outputs" not in data and "data" in data:
140
+ data = data["data"]
141
+
142
+ # Print warning message if there is any.
143
+ if "warning_messages" in data and data["warning_messages"] is not None:
144
+ for msg in data["warning_messages"]:
145
+ print("\033[31m", msg, "\033[0m")
146
+
147
+ return data
148
+
149
+
150
+ class ESM3ForgeInferenceClient(ESM3InferenceClient):
151
+ def __init__(
152
+ self,
153
+ model: str,
154
+ url: str = "https://forge.evolutionaryscale.ai",
155
+ token: str = "",
156
+ request_timeout: int | None = None,
157
+ min_retry_wait: int = 1,
158
+ max_retry_wait: int = 10,
159
+ max_retry_attempts: int = 5,
160
+ ):
161
+ if token == "":
162
+ raise RuntimeError(
163
+ "Please provide a token to connect to Forge via token=YOUR_API_TOKEN_HERE"
164
+ )
165
+ self.model = model # Name of the model to run.
166
+ self.url = url
167
+ self.token = token
168
+ self.headers = {"Authorization": f"Bearer {self.token}"}
169
+ self.request_timeout = request_timeout
170
+ self.min_retry_wait = min_retry_wait
171
+ self.max_retry_wait = max_retry_wait
172
+ self.max_retry_attempts = max_retry_attempts
173
+
174
+ @staticmethod
175
+ def retry_decorator(func):
176
+ """
177
+ A static method that returns a retry decorator. This decorator uses the
178
+ instance's retry settings.
179
+ """
180
+
181
+ @wraps(func)
182
+ def wrapper(instance, *args, **kwargs):
183
+ retry_decorator = retry(
184
+ retry=retry_if_result(retry_if_specific_error),
185
+ wait=wait_exponential(
186
+ multiplier=1,
187
+ min=instance.min_retry_wait,
188
+ max=instance.max_retry_wait,
189
+ ),
190
+ stop=stop_after_attempt(instance.max_retry_attempts),
191
+ before_sleep=log_retry_attempt,
192
+ )
193
+ # Apply the retry decorator to the function
194
+ return retry_decorator(func)(instance, *args, **kwargs)
195
+
196
+ return wrapper
197
+
198
+ @retry_decorator
199
+ def generate(self, input: ProteinType, config: GenerationConfig) -> ProteinType:
200
+ if isinstance(input, ESMProtein):
201
+ output = self.__generate_protein(input, config)
202
+ elif isinstance(input, ESMProteinTensor):
203
+ output = self.__generate_protein_tensor(input, config)
204
+ else:
205
+ return ESMProteinError(
206
+ error_code=500, error_msg=f"Unknown input type {type(input)}"
207
+ )
208
+
209
+ if (
210
+ isinstance(output, ESMProtein)
211
+ and isinstance(input, ESMProtein)
212
+ and config.track not in ["function", "residue_annotations"]
213
+ ):
214
+ # Function and residue annotation encoding/decoding is lossy
215
+ # There is no guarantee that decoding encoded tokens will yield the same input
216
+ output.function_annotations = input.function_annotations
217
+
218
+ return output
219
+
220
+ def batch_generate(
221
+ self, inputs: Sequence[ProteinType], configs: Sequence[GenerationConfig]
222
+ ) -> Sequence[ProteinType]:
223
+ """Forge supports auto-batching. So batch_generate() for the Forge client
224
+ is as simple as running a collection of generate() in parallel using asyncio.
225
+ """
226
+ loop = asyncio.get_event_loop()
227
+
228
+ async def _async_generate():
229
+ futures = [
230
+ loop.run_in_executor(None, self.generate, protein, config)
231
+ for protein, config in zip(inputs, configs)
232
+ ]
233
+ return await asyncio.gather(*futures, return_exceptions=True)
234
+
235
+ results = loop.run_until_complete(_async_generate())
236
+
237
+ def _capture_exception(r):
238
+ if isinstance(r, BaseException) and not isinstance(r, ESMProteinError):
239
+ return ESMProteinError(500, str(r))
240
+ return r
241
+
242
+ return [_capture_exception(r) for r in results]
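
# Usage sketch for the auto-batching path above (illustrative inputs, not
# part of this file; forge_client is a constructed ESM3ForgeInferenceClient):
#   proteins = [ESMProtein(sequence="M" + "_" * 20) for _ in range(4)]
#   configs = [GenerationConfig(track="sequence", num_steps=4) for _ in range(4)]
#   outputs = forge_client.batch_generate(proteins, configs)
#   failed = [o for o in outputs if isinstance(o, ESMProteinError)]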
243
+
244
+ def __generate_protein(
245
+ self, input: ESMProtein, config: GenerationConfig
246
+ ) -> ESMProtein | ESMProteinError:
247
+ req = {}
248
+ req["sequence"] = input.sequence
249
+ req["secondary_structure"] = input.secondary_structure
250
+ req["sasa"] = input.sasa
251
+ if input.function_annotations is not None:
252
+ req["function"] = [x.to_tuple() for x in input.function_annotations]
253
+ req["coordinates"] = maybe_list(input.coordinates, convert_nan_to_none=True)
254
+
255
+ request = {
256
+ "model": self.model,
257
+ "inputs": req,
258
+ "track": config.track,
259
+ "invalid_ids": config.invalid_ids,
260
+ "schedule": config.schedule,
261
+ "num_steps": config.num_steps,
262
+ "temperature": config.temperature,
263
+ "top_p": config.top_p,
264
+ "condition_on_coordinates_only": config.condition_on_coordinates_only,
265
+ }
266
+ try:
267
+ data = self._post("generate", request, input.potential_sequence_of_concern)
268
+ except ESMProteinError as e:
269
+ return e
270
+
271
+ return ESMProtein(
272
+ sequence=data["outputs"]["sequence"],
273
+ secondary_structure=data["outputs"]["secondary_structure"],
274
+ sasa=data["outputs"]["sasa"],
275
+ function_annotations=_list_to_function_annotations(
276
+ data["outputs"]["function"]
277
+ ),
278
+ coordinates=maybe_tensor(
279
+ data["outputs"]["coordinates"], convert_none_to_nan=True
280
+ ),
281
+ plddt=maybe_tensor(data["outputs"]["plddt"]),
282
+ ptm=maybe_tensor(data["outputs"]["ptm"]),
283
+ )
284
+
285
+ def __generate_protein_tensor(
286
+ self, input: ESMProteinTensor, config: GenerationConfig
287
+ ) -> ESMProteinTensor | ESMProteinError:
288
+ req = {}
289
+ req["sequence"] = maybe_list(input.sequence)
290
+ req["structure"] = maybe_list(input.structure)
291
+ req["secondary_structure"] = maybe_list(input.secondary_structure)
292
+ req["sasa"] = maybe_list(input.sasa)
293
+ req["function"] = maybe_list(input.function)
294
+ req["coordinates"] = maybe_list(input.coordinates, convert_nan_to_none=True)
295
+ req["residue_annotation"] = maybe_list(input.residue_annotations)
296
+
297
+ request = {
298
+ "model": self.model,
299
+ "inputs": req,
300
+ "track": config.track,
301
+ "invalid_ids": config.invalid_ids,
302
+ "schedule": config.schedule,
303
+ "num_steps": config.num_steps,
304
+ "temperature": config.temperature,
305
+ "top_p": config.top_p,
306
+ "condition_on_coordinates_only": config.condition_on_coordinates_only,
307
+ }
308
+
309
+ try:
310
+ data = self._post(
311
+ "generate_tensor", request, input.potential_sequence_of_concern
312
+ )
313
+ except ESMProteinError as e:
314
+ return e
315
+
316
+ def _field_to_tensor(field, convert_none_to_nan: bool = False):
317
+ if field not in data["outputs"]:
318
+ return None
319
+ return maybe_tensor(
320
+ data["outputs"][field], convert_none_to_nan=convert_none_to_nan
321
+ )
322
+
323
+ output = ESMProteinTensor(
324
+ sequence=_field_to_tensor("sequence"),
325
+ structure=_field_to_tensor("structure"),
326
+ secondary_structure=_field_to_tensor("secondary_structure"),
327
+ sasa=_field_to_tensor("sasa"),
328
+ function=_field_to_tensor("function"),
329
+ residue_annotations=_field_to_tensor("residue_annotation"),
330
+ coordinates=_field_to_tensor("coordinates", convert_none_to_nan=True),
331
+ )
332
+
333
+ return output
334
+
335
+ @retry_decorator
336
+ def forward_and_sample(
337
+ self, input: ESMProteinTensor, sampling_configuration: SamplingConfig
338
+ ) -> ForwardAndSampleOutput | ESMProteinError:
339
+ _validate_protein_tensor_input(input)
340
+ validate_sampling_config(sampling_configuration, on_invalid="raise")
341
+
342
+ req = {}
343
+ sampling_config = {}
344
+ embedding_config = {
345
+ "sequence": sampling_configuration.return_mean_embedding,
346
+ "per_residue": sampling_configuration.return_per_residue_embeddings,
347
+ }
348
+
349
+ req["sequence"] = maybe_list(input.sequence)
350
+ req["structure"] = maybe_list(input.structure)
351
+ req["secondary_structure"] = maybe_list(input.secondary_structure)
352
+ req["sasa"] = maybe_list(input.sasa)
353
+ req["function"] = maybe_list(input.function)
354
+ req["coordinates"] = maybe_list(input.coordinates, convert_nan_to_none=True)
355
+ req["residue_annotation"] = maybe_list(input.residue_annotations)
356
+
357
+ def do_track(t: str):
358
+ track: SamplingTrackConfig | None
359
+ if (track := getattr(sampling_configuration, t, None)) is None:
360
+ sampling_config[t] = None
361
+ else:
362
+ sampling_config[t] = {
363
+ "temperature": track.temperature,
364
+ "top_p": track.top_p,
365
+ "only_sample_masked_tokens": track.only_sample_masked_tokens,
366
+ "invalid_ids": track.invalid_ids,
367
+ "topk_logprobs": track.topk_logprobs,
368
+ }
369
+
370
+ do_track("sequence")
371
+ do_track("structure")
372
+ do_track("secondary_structure")
373
+ do_track("sasa")
374
+ do_track("function")
375
+
376
+ request = {
377
+ "model": self.model,
378
+ "inputs": req,
379
+ "sampling_config": sampling_config,
380
+ "embedding_config": embedding_config,
381
+ }
382
+ try:
383
+ data = self._post(
384
+ "forward_and_sample", request, input.potential_sequence_of_concern
385
+ )
386
+ except ESMProteinError as e:
387
+ return e
388
+
389
+ def get(k, field):
390
+ if data[k] is None:
391
+ return None
392
+ v = data[k][field]
393
+ return torch.tensor(v) if v is not None else None
394
+
395
+ tokens = ESMProteinTensor(
396
+ sequence=get("sequence", "tokens"),
397
+ structure=get("structure", "tokens"),
398
+ secondary_structure=get("secondary_structure", "tokens"),
399
+ sasa=get("sasa", "tokens"),
400
+ function=get("function", "tokens"),
401
+ )
402
+
403
+ def get_track(field):
404
+ return ForwardTrackData(
405
+ sequence=get("sequence", field),
406
+ structure=get("structure", field),
407
+ secondary_structure=get("secondary_structure", field),
408
+ sasa=get("sasa", field),
409
+ function=get("function", field),
410
+ )
411
+
412
+ def operate_on_track(track: ForwardTrackData, fn):
413
+ apply = lambda x: fn(x) if x is not None else None
414
+ return ForwardTrackData(
415
+ sequence=apply(track.sequence),
416
+ structure=apply(track.structure),
417
+ secondary_structure=apply(track.secondary_structure),
418
+ sasa=apply(track.sasa),
419
+ function=apply(track.function),
420
+ )
421
+
422
+ logprob = get_track("logprobs")
423
+ output = ForwardAndSampleOutput(
424
+ protein_tensor=tokens,
425
+ logprob=logprob,
426
+ prob=operate_on_track(logprob, torch.exp),
427
+ entropy=get_track("entropy"),
428
+ topk_logprob=get_track("topk_logprobs"),
429
+ topk_tokens=get_track("topk_tokens"),
430
+ per_residue_embedding=data["embeddings"]["per_residue"],
431
+ mean_embedding=data["embeddings"]["sequence"],
432
+ )
433
+ return output
434
+
435
+ @retry_decorator
436
+ def encode(self, input: ESMProtein) -> ESMProteinTensor | ESMProteinError:
437
+ tracks = {}
438
+ tracks["sequence"] = input.sequence
439
+ tracks["secondary_structure"] = input.secondary_structure
440
+ tracks["sasa"] = input.sasa
441
+ if input.function_annotations is not None:
442
+ tracks["function"] = [x.to_tuple() for x in input.function_annotations]
443
+ tracks["coordinates"] = maybe_list(input.coordinates, convert_nan_to_none=True)
444
+
445
+ request = {"inputs": tracks, "model": self.model}
446
+
447
+ try:
448
+ data = self._post("encode", request, input.potential_sequence_of_concern)
449
+ except ESMProteinError as e:
450
+ return e
451
+
452
+ return ESMProteinTensor(
453
+ sequence=maybe_tensor(data["outputs"]["sequence"]),
454
+ structure=maybe_tensor(data["outputs"]["structure"]),
455
+ coordinates=maybe_tensor(
456
+ data["outputs"]["coordinates"], convert_none_to_nan=True
457
+ ),
458
+ secondary_structure=maybe_tensor(data["outputs"]["secondary_structure"]),
459
+ sasa=maybe_tensor(data["outputs"]["sasa"]),
460
+ function=maybe_tensor(data["outputs"]["function"]),
461
+ residue_annotations=maybe_tensor(data["outputs"]["residue_annotation"]),
462
+ )
463
+
464
+ @retry_decorator
465
+ def decode(self, input: ESMProteinTensor) -> ESMProtein | ESMProteinError:
466
+ _validate_protein_tensor_input(input)
467
+
468
+ tokens = {}
469
+ tokens["sequence"] = maybe_list(input.sequence)
470
+ tokens["structure"] = maybe_list(input.structure)
471
+ tokens["secondary_structure"] = maybe_list(input.secondary_structure)
472
+ tokens["sasa"] = maybe_list(input.sasa)
473
+ tokens["function"] = maybe_list(input.function)
474
+ tokens["residue_annotation"] = maybe_list(input.residue_annotations)
475
+ tokens["coordinates"] = maybe_list(input.coordinates, convert_nan_to_none=True)
476
+
477
+ request = {"model": self.model, "inputs": tokens}
478
+
479
+ try:
480
+ data = self._post("decode", request, input.potential_sequence_of_concern)
481
+ except ESMProteinError as e:
482
+ return e
483
+
484
+ return ESMProtein(
485
+ sequence=data["outputs"]["sequence"],
486
+ secondary_structure=data["outputs"]["secondary_structure"],
487
+ sasa=data["outputs"]["sasa"],
488
+ function_annotations=_list_to_function_annotations(
489
+ data["outputs"]["function"]
490
+ ),
491
+ coordinates=maybe_tensor(
492
+ data["outputs"]["coordinates"], convert_none_to_nan=True
493
+ ),
494
+ plddt=maybe_tensor(data["outputs"]["plddt"]),
495
+ ptm=maybe_tensor(data["outputs"]["ptm"]),
496
+ )
497
+
498
+ @retry_decorator
499
+ def logits(
500
+ self, input: ESMProteinTensor, config: LogitsConfig = LogitsConfig()
501
+ ) -> LogitsOutput | ESMProteinError:
502
+ _validate_protein_tensor_input(input)
503
+
504
+ # Note: using raw model forwards is discouraged because of the byte size
505
+ # of the logits.
506
+ # Please use forward_and_sample instead.
507
+ req = {}
508
+ req["sequence"] = maybe_list(input.sequence)
509
+ req["structure"] = maybe_list(input.structure)
510
+ req["secondary_structure"] = maybe_list(input.secondary_structure)
511
+ req["sasa"] = maybe_list(input.sasa)
512
+ req["function"] = maybe_list(input.function)
513
+ req["coordinates"] = maybe_list(input.coordinates, convert_nan_to_none=True)
514
+ req["residue_annotation"] = maybe_list(input.residue_annotations)
515
+
516
+ logits_config = {
517
+ "sequence": config.sequence,
518
+ "structure": config.structure,
519
+ "secondary_structure": config.secondary_structure,
520
+ "sasa": config.sasa,
521
+ "function": config.function,
522
+ "residue_annotations": config.residue_annotations,
523
+ "return_embeddings": config.return_embeddings,
524
+ }
525
+
526
+ request = {"model": self.model, "inputs": req, "logits_config": logits_config}
527
+
528
+ try:
529
+ data = self._post("logits", request, input.potential_sequence_of_concern)
530
+ except ESMProteinError as e:
531
+ return e
532
+
533
+ def _maybe_logits(track: str):
534
+ if "logits" in data and track in data["logits"]:
535
+ return maybe_tensor(data["logits"][track])
536
+ return None
537
+
538
+ output = LogitsOutput(
539
+ logits=ForwardTrackData(
540
+ sequence=_maybe_logits("sequence"),
541
+ structure=_maybe_logits("structure"),
542
+ secondary_structure=_maybe_logits("secondary_structure"),
543
+ sasa=_maybe_logits("sasa"),
544
+ function=_maybe_logits("function"),
545
+ ),
546
+ embeddings=maybe_tensor(data["embeddings"]),
547
+ residue_annotation_logits=_maybe_logits("residue_annotation"),
548
+ )
549
+
550
+ return output
551
+
552
+ def _post(self, endpoint, request, potential_sequence_of_concern):
553
+ request["potential_sequence_of_concern"] = potential_sequence_of_concern
554
+
555
+ response = requests.post(
556
+ urljoin(self.url, f"/api/v1/{endpoint}"),
557
+ json=request,
558
+ headers=self.headers,
559
+ timeout=self.request_timeout,
560
+ )
561
+
562
+ if not response.ok:
563
+ raise ESMProteinError(
564
+ error_code=response.status_code,
565
+ error_msg=f"Failure in {endpoint}: {response.text}",
566
+ )
567
+
568
+ data = response.json()
569
+ # Next.js puts the outputs dict under a "data" key.
570
+ # Lift it up for easier downstream processing.
571
+ if "outputs" not in data and "data" in data:
572
+ data = data["data"]
573
+
574
+ return data
575
+
576
+ @property
577
+ def raw_model(self):
578
+ raise NotImplementedError(
579
+ f"Can not get underlying remote model {self.model} from a Forge client."
580
+ )
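Taken together, the methods above form an encode → logits/sample → decode loop against a remote Forge deployment. Below is a minimal round-trip sketch, not a verbatim recipe from this repo: the model name, URL, and token are placeholders, and the constructor arguments assume the `ESM3ForgeInferenceClient` signature defined earlier in this file.

```python
from esm.sdk.api import ESMProtein, ESMProteinError
from esm.sdk.forge import ESM3ForgeInferenceClient

# Placeholder credentials/endpoint; substitute real values.
client = ESM3ForgeInferenceClient(
    model="esm3-open", url="https://forge.example.com", token="<TOKEN>"
)

protein = ESMProtein(sequence="MVLSPADKTNVKAAW")
tensor = client.encode(protein)  # server-side tokenization of all tracks
if isinstance(tensor, ESMProteinError):
    raise RuntimeError(tensor.error_msg)
decoded = client.decode(tensor)  # round-trips back to an ESMProtein
```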
Dyna-1/esm/sdk/sagemaker.py ADDED
@@ -0,0 +1,110 @@
1
+ import json
2
+
3
+ import boto3
4
+
5
+ from esm.sdk.forge import (
6
+ ESM3ForgeInferenceClient,
7
+ SequenceStructureForgeInferenceClient,
8
+ )
9
+
10
+
11
+ class SequenceStructureSageMakerClient(SequenceStructureForgeInferenceClient):
12
+ def __init__(self, endpoint_name: str):
13
+ """SequenceStructure (folding and inverse folding) client that talks to a SageMaker endpoint.
14
+
15
+ Args:
16
+ endpoint_name: Name of the SageMaker endpoint.
17
+ """
18
+ # Dummy URL and token to make SequenceStructureForgeInferenceClient happy.
19
+ super().__init__(url="", token="dummy")
20
+
21
+ self._endpoint_name = endpoint_name
22
+
23
+ self._client = boto3.client(service_name="sagemaker-runtime")
24
+
25
+ def _post(self, endpoint, request, potential_sequence_of_concern):
26
+ request["potential_sequence_of_concern"] = potential_sequence_of_concern
27
+ request["model"] = request.get("model", None)
28
+ invocations_request = {
29
+ # Duplicate these fields at the top level to make Forge requests consistent.
30
+ "model": request["model"],
31
+ "request_id": "", # Forge specific field.
32
+ "user_id": "", # Forge specific field.
33
+ # Invocation data bits.
34
+ "api_ver": "v1", # Must be v1 right now.
35
+ "endpoint": endpoint,
36
+ # Wrapped request.
37
+ endpoint: request,
38
+ }
39
+
40
+ try:
41
+ response = self._client.invoke_endpoint(
42
+ EndpointName=self._endpoint_name,
43
+ ContentType="application/json",
44
+ Body=json.dumps(invocations_request),
45
+ )
46
+ except Exception as e:
47
+ raise RuntimeError(f"Failure in {endpoint}: {e}") from e
48
+
49
+ data = json.loads(response["Body"].read().decode())
50
+
51
+ # Response must match request.
52
+ assert (
53
+ data["endpoint"] == endpoint
54
+ ), f"Response endpoint is {data['endpoint']} but request is {endpoint}"
55
+
56
+ # Get the actual responses under the endpoint key.
57
+ data = data[endpoint]
58
+
59
+ return data
60
+
61
+
62
+ class ESM3SageMakerClient(ESM3ForgeInferenceClient):
63
+ def __init__(self, endpoint_name: str, model: str):
64
+ """ESM3 client that talks to a SageMaker endpoint.
65
+
66
+ Args:
67
+ endpoint_name: Name of the SageMaker endpoint.
68
+ model: Name of the ESM3 model.
69
+ """
70
+ # Dummy URL and token to make ESM3ForgeInferenceClient happy.
71
+ super().__init__(model=model, url="", token="dummy")
72
+
73
+ self._endpoint_name = endpoint_name
74
+ self._model = model
75
+
76
+ self._client = boto3.client(service_name="sagemaker-runtime")
77
+
78
+ def _post(self, endpoint, request, potential_sequence_of_concern):
79
+ request["potential_sequence_of_concern"] = potential_sequence_of_concern
80
+
81
+ invocations_request = {
82
+ # Duplicate these fields at the top level to make Forge requests consistent.
83
+ "model": request["model"],
84
+ "request_id": "", # Forge specific field.
85
+ "user_id": "", # Forge specific field.
86
+ # Invocation data bits.
87
+ "api_ver": "v1", # Must be v1 right now.
88
+ "endpoint": endpoint,
89
+ # Wrapped request.
90
+ endpoint: request,
91
+ }
92
+
93
+ try:
94
+ response = self._client.invoke_endpoint(
95
+ EndpointName=self._endpoint_name,
96
+ ContentType="application/json",
97
+ Body=json.dumps(invocations_request),
98
+ )
99
+ except Exception as e:
100
+ raise RuntimeError(f"Failure in {endpoint}: {e}") from e
101
+
102
+ data = json.loads(response["Body"].read().decode())
103
+
104
+ # Response must match request.
105
+ assert data["endpoint"] == endpoint, f"Response endpoint is {data['endpoint']} but request is {endpoint}"
106
+
107
+ # Get the actual responses under the endpoint key.
108
+ data = data[endpoint]
109
+
110
+ return data
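Both SageMaker clients reuse the Forge request/response shape and only swap the HTTP transport for boto3's `invoke_endpoint`. A minimal usage sketch, assuming AWS credentials are configured and an endpoint is already deployed; the endpoint and model names below are placeholders:

```python
from esm.sdk.api import ESMProtein
from esm.sdk.sagemaker import ESM3SageMakerClient

client = ESM3SageMakerClient(
    endpoint_name="my-esm3-endpoint",  # placeholder SageMaker endpoint
    model="esm3-open",                 # placeholder model name
)
# Inherited encode/decode/logits calls now route through invoke_endpoint.
tensor = client.encode(ESMProtein(sequence="MKTAYIAKQR"))
```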
Dyna-1/esm/tokenization/__init__.py ADDED
@@ -0,0 +1,69 @@
1
+ from dataclasses import dataclass
2
+ from typing import Protocol
3
+
4
+ from esm.utils.constants.models import (
5
+ ESM3_OPEN_SMALL,
6
+ normalize_model_name,
7
+ )
8
+
9
+ from .function_tokenizer import InterProQuantizedTokenizer
10
+ from .residue_tokenizer import ResidueAnnotationsTokenizer
11
+ from .sasa_tokenizer import SASADiscretizingTokenizer
12
+ from .sequence_tokenizer import EsmSequenceTokenizer
13
+ from .ss_tokenizer import SecondaryStructureTokenizer
14
+ from .structure_tokenizer import StructureTokenizer
15
+ from .tokenizer_base import EsmTokenizerBase
16
+
17
+
18
+ class TokenizerCollectionProtocol(Protocol):
19
+ sequence: EsmSequenceTokenizer
20
+ structure: StructureTokenizer
21
+ secondary_structure: SecondaryStructureTokenizer
22
+ sasa: SASADiscretizingTokenizer
23
+ function: InterProQuantizedTokenizer
24
+ residue_annotations: ResidueAnnotationsTokenizer
25
+
26
+
27
+ @dataclass
28
+ class TokenizerCollection:
29
+ sequence: EsmSequenceTokenizer
30
+ structure: StructureTokenizer
31
+ secondary_structure: SecondaryStructureTokenizer
32
+ sasa: SASADiscretizingTokenizer
33
+ function: InterProQuantizedTokenizer
34
+ residue_annotations: ResidueAnnotationsTokenizer
35
+
36
+
37
+ def get_esm3_model_tokenizers(model: str = ESM3_OPEN_SMALL) -> TokenizerCollection:
38
+ if normalize_model_name(model) == ESM3_OPEN_SMALL:
39
+ return TokenizerCollection(
40
+ sequence=EsmSequenceTokenizer(),
41
+ structure=StructureTokenizer(),
42
+ secondary_structure=SecondaryStructureTokenizer(kind="ss8"),
43
+ sasa=SASADiscretizingTokenizer(),
44
+ function=InterProQuantizedTokenizer(),
45
+ residue_annotations=ResidueAnnotationsTokenizer(),
46
+ )
47
+ else:
48
+ raise ValueError(f"Unknown model: {model}")
49
+
50
+
51
+ def get_esmc_model_tokenizers() -> EsmSequenceTokenizer:
52
+ return EsmSequenceTokenizer()
53
+
54
+
55
+ def get_invalid_tokenizer_ids(tokenizer: EsmTokenizerBase) -> list[int]:
56
+ if isinstance(tokenizer, EsmSequenceTokenizer):
57
+ return [
58
+ tokenizer.mask_token_id, # type: ignore
59
+ tokenizer.pad_token_id, # type: ignore
60
+ tokenizer.cls_token_id, # type: ignore
61
+ tokenizer.eos_token_id, # type: ignore
62
+ ]
63
+ else:
64
+ return [
65
+ tokenizer.mask_token_id,
66
+ tokenizer.pad_token_id,
67
+ tokenizer.bos_token_id,
68
+ tokenizer.eos_token_id,
69
+ ]
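A short sketch of how these helpers fit together: build the collection once, then pull per-track tokenizers off it. Note that constructing the residue-annotation tokenizer resolves its data path eagerly, which can trigger a download of the model data on first use.

```python
from esm.tokenization import (
    get_esm3_model_tokenizers,
    get_invalid_tokenizer_ids,
)

tokenizers = get_esm3_model_tokenizers()  # defaults to ESM3_OPEN_SMALL

# The sequence track is a HF fast tokenizer, so it is directly callable.
ids = tokenizers.sequence("MKT")["input_ids"]

# Token ids that sampling should never emit for this track.
banned = get_invalid_tokenizer_ids(tokenizers.sequence)
```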
Dyna-1/esm/tokenization/function_tokenizer.py ADDED
@@ -0,0 +1,429 @@
1
+ """Tokenizes annotations of protein function."""
2
+
3
+ import re
4
+ import string
5
+ from functools import cache, cached_property, partial
6
+ from typing import Collection
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ import scipy.sparse as sp
11
+ import torch
12
+ import torch.nn.functional as F
13
+
14
+ from esm.tokenization.tokenizer_base import EsmTokenizerBase
15
+ from esm.utils.constants import esm3 as C
16
+ from esm.utils.function import interpro, lsh, tfidf
17
+ from esm.utils.misc import stack_variable_length_tensors
18
+ from esm.utils.types import FunctionAnnotation, PathLike
19
+
20
+
21
+ def _default_data_path(x: PathLike | None, d: PathLike) -> PathLike:
22
+ return x if x is not None else C.data_root("esm3") / d
23
+
24
+
25
+ def _default_local_data_path(x: PathLike | None, d: PathLike) -> PathLike:
26
+ return x if x is not None else d
27
+
28
+
29
+ class InterProQuantizedTokenizer(EsmTokenizerBase):
30
+ """Tokenizer for functional annotations.
31
+
32
+ This tokenizer converts InterPro and/or function keywords into a multi-token
33
+ representation by hashing TF-IDF vector representations of the text associated with
34
+ the function and then applying a locality sensitive hash (LSH).
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ depth: int = 8,
40
+ lsh_bits_per_token: int = 8,
41
+ lsh_path: PathLike | None = None,
42
+ keyword_vocabulary_path: PathLike | None = None,
43
+ keyword_idf_path: PathLike | None = None,
44
+ interpro_entry_path: PathLike | None = None,
45
+ interpro2keywords_path: PathLike | None = None,
46
+ ):
47
+ """Constructs function tokenizer.
48
+
49
+ Args:
50
+ depth: number of tokens emitted in each position.
51
+ lsh_bits_per_token: Number of LSH bits per token. Determines the vocabulary
52
+ size.
53
+ lsh_path: path to locality sensitive hash (LSH) hyperplanes.
54
+ keyword_vocabulary_path: path to csv containing function keyword vocabulary.
55
+ keyword_idf_path: path to IDF values for each keyword.
56
+ interpro_entry_path: path to list of InterPro entries.
57
+ interpro2keywords_path: path to CSV mapping InterPro IDs to function keywords.
58
+ """
59
+ self.depth = depth
60
+
61
+ self.keyword_vocabulary_path = _default_local_data_path(
62
+ keyword_vocabulary_path, C.KEYWORDS_VOCABULARY
63
+ )
64
+ self.keyword_idf_path = _default_local_data_path(
65
+ keyword_idf_path, C.KEYWORDS_IDF
66
+ )
67
+
68
+ self._interpro2keywords_path = _default_local_data_path(
69
+ interpro2keywords_path, C.INTERPRO2KEYWORDS
70
+ )
71
+ self.interpro_ = interpro.InterPro(
72
+ entries_path=_default_local_data_path(interpro_entry_path, C.INTERPRO_ENTRY)
73
+ )
74
+
75
+ self.lsh_path = lsh_path
76
+ self.lsh_bits_per_token = lsh_bits_per_token
77
+ self.lsh_vocab_size = 1 << lsh_bits_per_token
78
+
79
+ # This is the offset into the vocabulary where LSH tokens start.
80
+ self._lsh_token_vocab_offset = len(self.special_tokens) + 1 # +1 for <none>
81
+
82
+ @cached_property
83
+ def _lsh(self) -> lsh.LSHTokenized:
84
+ """Locality sensitive hash for function annotations."""
85
+ return lsh.LSHTokenized(
86
+ self.lsh_bits_per_token,
87
+ len(self.keyword_vocabulary),
88
+ self.depth,
89
+ _default_data_path(self.lsh_path, C.LSH_TABLE_PATHS["8bit"]),
90
+ )
91
+
92
+ @cached_property
93
+ def interpro2keywords(self) -> dict[str, list[str]]:
94
+ """Mapping from InterPro ID to function keywords."""
95
+ df = pd.read_csv(self._interpro2keywords_path)
96
+ assert "interpro_id" in df.columns and "keywords" in df.columns, df.columns
97
+ return dict(zip(df.interpro_id, df.keywords.str.split(",")))
98
+
99
+ @cached_property
100
+ def interpro_labels(self) -> list[str]:
101
+ """The set of supported InterPro labels."""
102
+ return sorted(self.interpro2keywords.keys())
103
+
104
+ @cached_property
105
+ def interpro_to_index(self) -> dict[str, int]:
106
+ """Mapping from InterPro id to index."""
107
+ return {id: i for i, id in enumerate(self.interpro_labels)}
108
+
109
+ @property
110
+ def keyword_vocabulary(self) -> list[str]:
111
+ """Set of supported keywords."""
112
+ return self._tfidf.vocabulary
113
+
114
+ @property
115
+ def keyword_to_index(self) -> dict[str, int]:
116
+ """Mapping from keywords to index."""
117
+ return self._tfidf.vocab_to_index
118
+
119
+ @cached_property
120
+ def _tfidf(self) -> tfidf.TFIDFModel:
121
+ """Creates TF-IDF model for encoding function keywords."""
122
+ return tfidf.TFIDFModel(
123
+ vocabulary_path=self.keyword_vocabulary_path, idf_path=self.keyword_idf_path
124
+ )
125
+
126
+ @cached_property
127
+ def special_tokens(self) -> list[str]:
128
+ """List of special tokens which come before cluster tokens in vocab."""
129
+ return ["<pad>", "<motif>", "<unk>"]
130
+
131
+ @cached_property
132
+ def vocab(self) -> list[str]:
133
+ """Vocabulary of function tokens."""
134
+ lsh_tokens = [f"<lsh:{i}>" for i in range(self.lsh_vocab_size)]
135
+ return self.special_tokens + ["<none>"] + lsh_tokens
136
+
137
+ @cached_property
138
+ def vocab_to_index(self) -> dict[str, int]:
139
+ return {token: token_id for token_id, token in enumerate(self.vocab)}
140
+
141
+ def get_special_tokens_mask(self, encoded: torch.Tensor) -> torch.Tensor:
142
+ """Determines where in the sequence are special tokens."""
143
+ where = encoded < len(self.special_tokens)
144
+ assert torch.all(torch.all(where, dim=1) | torch.all(~where, dim=1))
145
+ return where[:, 0]
146
+
147
+ def tokenize(
148
+ self,
149
+ annotations: list[FunctionAnnotation],
150
+ seqlen: int,
151
+ p_keyword_dropout: float = 0.0,
152
+ ) -> list[str]:
153
+ """Encodes range-annotations of protein function as tokens.
154
+
155
+ Args:
156
+ annotations: Annotated function ranges, either as InterPro ids or keywords.
157
+ seqlen: length of sequence.
158
+ p_keyword_dropout: Optional probability of dropping out keywords from the
159
+ input annotations.
160
+ Returns:
161
+ Tokenized representation of function annotations as a list of string tokens
162
+ of size seqlen.
163
+ """
164
+ assert seqlen >= 0
165
+
166
+ if not annotations:
167
+ return ["<pad>"] * seqlen
168
+
169
+ # Expand the range annotations into positional annotation sets.
170
+ positional_labels: list[set[str]] = [set() for _ in range(seqlen)]
171
+ for annotation in annotations:
172
+ assert 1 <= annotation.start <= annotation.end <= seqlen, (
173
+ f"Invalid annotation range [{annotation.start}, {annotation.end}] for "
174
+ f"sequence length {seqlen}."
175
+ )
176
+ for i in range(annotation.start - 1, annotation.end):
177
+ positional_labels[i].add(annotation.label)
178
+
179
+ if p_keyword_dropout > 0:
180
+ keyword_mask = (
181
+ np.random.random(len(self._tfidf.vocabulary)) < p_keyword_dropout
182
+ )
183
+ else:
184
+ keyword_mask = None
185
+
186
+ # Annotations tend to be repetitive over the length of the sequence - cache their
187
+ # hashes to speed up tokenization.
188
+ hash_fn = cache(partial(self._function_text_hash, keyword_mask=keyword_mask))
189
+
190
+ tokens: list[str] = []
191
+ for labels in positional_labels:
192
+ if not labels:
193
+ token = "<none>"
194
+ else:
195
+ lsh_hash = hash_fn(frozenset(labels))
196
+ if lsh_hash is not None:
197
+ assert len(lsh_hash) == self.depth
198
+ token = "<lsh:" + ",".join(map(str, lsh_hash)) + ">"
199
+ else:
200
+ token = "<unk>"
201
+
202
+ tokens.append(token)
203
+
204
+ return tokens
205
+
206
+ def _function_text_hash(
207
+ self, labels: Collection[str], keyword_mask: np.ndarray | None = None
208
+ ) -> np.ndarray | None:
209
+ """Applies a locality sensitive hash (LSH) to function text.
210
+
211
+ Args:
212
+ labels: InterPro ids and/or keywords.
213
+ keyword_mask: optional boolean array shaped (keyword_vocab_size,) indicating
214
+ which keywords to drop before hashing.
215
+ Returns:
216
+ LSH shaped (depth,) or None if there is no text or keywords to hash.
217
+ """
218
+ # Split labels into either InterPro ids or keywords.
219
+ interpro_ids = []
220
+ keywords = []
221
+ for label in labels:
222
+ match = re.search(r"IPR\d+", label)
223
+ if match and match.group() in self.interpro_to_index:
224
+ interpro_ids.append(match.group())
225
+ elif label in self._tfidf.vocab_to_index:
226
+ keywords.append(label)
227
+ else:
228
+ raise ValueError(f"Unsupported: {label}")
229
+
232
+ # Perform an element-wise maximum over TF-IDF vectors from distinct tags to
233
+ # avoid tags getting "washed out" by eg. 4 very similar tags. Keywords are
234
+ # incorporated as another TF-IDF vector
235
+ vec: sp.csr_matrix = self._tfidf.encode(keywords)
236
+ for interpro_id in interpro_ids:
237
+ interpro_keywords = self.interpro2keywords.get(interpro_id, [])
238
+ vec_ = self._tfidf.encode(interpro_keywords)
239
+ vec = vec.maximum(vec_)
240
+
241
+ if keyword_mask is not None:
242
+ vec.data *= 1 - np.take(keyword_mask, vec.indices)
243
+
244
+ if vec.sum() == 0:
245
+ return None
246
+
247
+ return self._lsh(vec)[0, :]
248
+
249
+ def encode(
250
+ self, tokens: list[str], add_special_tokens: bool = True
251
+ ) -> torch.Tensor:
252
+ """Encodes string tokens as token-id tensor.
253
+
254
+ Args:
255
+ tokens: list of individual tokens. e.g. ["<none>", "<lsh:1,2,3,4>"]
256
+ add_special_tokens: whether to add a single pad token at the start and end
257
+ of the sequence to act as <cls> and <eos> tokens.
258
+ Returns:
259
+ <int>[length, depth] function tokens. Length will be +2 of input tokens
260
+ length when add_special_tokens is True.
261
+ """
262
+ token_ids = torch.zeros(size=(len(tokens), self.depth), dtype=torch.int64)
263
+ for i, token in enumerate(tokens):
264
+ token_ids[i, :] = torch.tensor(self._token2ids(token))
265
+ if add_special_tokens:
266
+ token_ids = F.pad(
267
+ token_ids, (0, 0, 1, 1), value=self.vocab_to_index["<pad>"]
268
+ )
269
+ return token_ids
270
+
271
+ def lookup_annotation_name(self, annotation: FunctionAnnotation) -> str | None:
272
+ return self.interpro_.lookup_name(annotation.label)
273
+
274
+ def format_annotation(self, annotation: FunctionAnnotation) -> str:
275
+ annotation_name = self.lookup_annotation_name(annotation)
276
+ if annotation_name is not None:
277
+ return f"{annotation_name} ({annotation.label})"
278
+ else:
279
+ return annotation.label
280
+
281
+ def _token2ids(self, token: str) -> list[int]:
282
+ """Converts token into token_id set of length depth."""
283
+ if re.match(r"<lsh:[\d+,]+>", token):
284
+ lsh_ids = [int(lsh_id) for lsh_id in re.findall(r"\d+", token)]
285
+ assert (
286
+ len(lsh_ids) == self.depth
287
+ ), f"Expected token to have {self.depth} ids found {lsh_ids}"
288
+ return [self._lsh_token_vocab_offset + lsh_id for lsh_id in lsh_ids]
289
+ elif token == "<none>" or token in self.special_tokens:
290
+ return [self.vocab_to_index[token]] * self.depth
291
+ else:
292
+ raise ValueError(f"Unknown token: {token}")
293
+
294
+ def batch_encode(
295
+ self, token_batch: list[list[str]], add_special_tokens: bool = True
296
+ ) -> torch.Tensor:
297
+ """Encodes batch of function tokens.
298
+
299
+ Args:
300
+ token_batch: batch of function tokens.
301
+ add_special_tokens: whether to add special tokens.
302
+ Returns:
303
+ <int>[batch_size, max_length, depth] batch of encoded tokens.
304
+ """
305
+ encoded = [
306
+ self.encode(tokens, add_special_tokens=add_special_tokens)
307
+ for tokens in token_batch
308
+ ]
309
+ return stack_variable_length_tensors(
310
+ encoded, constant_value=self.vocab_to_index["<pad>"]
311
+ )
312
+
313
+ def decode(self, encoded: torch.Tensor):
314
+ raise NotImplementedError(
315
+ "Function token decoding should be handled with "
316
+ "util.decoding.decode_function_annotations"
317
+ )
318
+
319
+ @property
320
+ def mask_token(self) -> str:
321
+ return "<pad>"
322
+
323
+ @property
324
+ def mask_token_id(self) -> int:
325
+ return self.vocab_to_index[self.mask_token]
326
+
327
+ @property
328
+ def bos_token(self) -> str:
329
+ return "<pad>"
330
+
331
+ @property
332
+ def bos_token_id(self) -> int:
333
+ return self.vocab_to_index[self.bos_token]
334
+
335
+ @property
336
+ def eos_token(self) -> str:
337
+ return "<pad>"
338
+
339
+ @property
340
+ def eos_token_id(self) -> int:
341
+ return self.vocab_to_index[self.eos_token]
342
+
343
+ @property
344
+ def pad_token(self) -> str:
345
+ return "<pad>"
346
+
347
+ @property
348
+ def pad_token_id(self) -> int:
349
+ return self.vocab_to_index[self.pad_token]
350
+
351
+ @property
352
+ def chain_break_token(self) -> str:
353
+ return "<pad>"
354
+
355
+ @property
356
+ def chain_break_token_id(self) -> int:
357
+ return self.vocab_to_index[self.chain_break_token]
358
+
359
+ @property
360
+ def all_token_ids(self):
361
+ return list(range(len(self.vocab)))
362
+
363
+ @property
364
+ def special_token_ids(self):
365
+ return [self.vocab_to_index[token] for token in self.special_tokens]
366
+
367
+
368
+ def _texts_to_keywords(texts: list[str]) -> list[str]:
369
+ """Breaks InterPro/GO free-text description set into bag-of-n-grams for n={1,2}.
370
+
371
+ Args:
372
+ texts: collection of text descriptions, i.e. InterPro/GO names.
373
+ Returns:
374
+ Collection of terms/n-grams
375
+ """
376
+ keywords = []
377
+ for text in texts:
378
+ keywords.extend(_keywords_from_text(text))
379
+ return keywords
380
+
381
+
382
+ def _keywords_from_text(text: str) -> list[str]:
383
+ """Splits text into unigrams and bigrams."""
384
+ elements = text.split(", ")
385
+
386
+ terms = []
387
+ for element in elements:
388
+ element = _sanitize(element)
389
+ words = element.split()
390
+
391
+ # Add 1-mers
392
+ terms.extend(words)
393
+
394
+ # Add 2-mers
395
+ for i in range(len(words) - 1):
396
+ bigram = words[i] + " " + words[i + 1]
397
+ terms.append(bigram)
398
+
399
+ return [term for term in terms if len(term) > 1 and term not in _EXCLUDED_TERMS]
400
+
401
+
402
+ def _sanitize(text: str) -> str:
403
+ text = text.replace("-", " ")
404
+ text = text.translate(str.maketrans("", "", string.punctuation))
405
+ text = text.lower()
406
+ return text
407
+
408
+
409
+ # These terms are omitted from textual representations since they are pervasive and
410
+ # unspecific to particular protein function.
411
+ _EXCLUDED_TERMS = {
412
+ "binding domain",
413
+ "biological_process",
414
+ "biological process",
415
+ "biologicalprocess",
416
+ "c",
417
+ "cellular_component",
418
+ "cellular component",
419
+ "cellularcomponent",
420
+ "cellular_process",
421
+ "cellularprocess",
422
+ "cellular process",
424
+ "like domain",
425
+ "molecular function",
426
+ "molecular_function",
427
+ "molecularfunction",
428
+ "n",
429
+ }
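A minimal sketch of the tokenize → encode flow. The keyword/LSH data files load lazily on first use, and `IPR000001` is only an illustrative id: it must be present in the safety-filtered InterPro set, otherwise `tokenize` raises `ValueError` for the unsupported label.

```python
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
from esm.utils.types import FunctionAnnotation

tok = InterProQuantizedTokenizer()  # depth=8, 8-bit LSH by default
tokens = tok.tokenize(
    [FunctionAnnotation(label="IPR000001", start=3, end=6)], seqlen=10
)
# Positions 3-6 (1-indexed) get an "<lsh:...>" token carrying `depth` ids;
# unannotated positions are "<none>".
encoded = tok.encode(tokens)  # <int>[10 + 2, 8]; BOS/EOS rows are <pad>
```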
Dyna-1/esm/tokenization/residue_tokenizer.py ADDED
@@ -0,0 +1,236 @@
1
+ from functools import cached_property
2
+ from typing import Any
3
+
4
+ import pandas as pd
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from cloudpathlib import AnyPath
8
+
9
+ from esm.tokenization.tokenizer_base import EsmTokenizerBase
10
+ from esm.utils.constants import esm3 as C
11
+
12
+ Sample = dict[str, Any]
13
+
14
+
15
+ class ResidueAnnotationsTokenizer(EsmTokenizerBase):
16
+ def __init__(self, csv_path: str | None = None, max_annotations: int = 16):
17
+ if csv_path is None:
18
+ csv_path = str(C.data_root("esm3") / C.RESID_CSV)
19
+ self.csv_path = csv_path
20
+ self.max_annotations = max_annotations
21
+
22
+ @cached_property
23
+ def _description2label(self) -> dict[str, str]:
24
+ with AnyPath(self.csv_path).open() as f: # type: ignore
25
+ df = pd.read_csv(f)
26
+ return dict(zip(df.label, df.label_clean))
27
+
28
+ @cached_property
29
+ def _labels(self) -> list[str]:
30
+ with AnyPath(self.csv_path).open() as f: # type: ignore
31
+ df = pd.read_csv(f)
32
+ labels = (
33
+ df.groupby("label_clean")["count"]
34
+ .sum()
35
+ .sort_values(ascending=False, kind="stable") # type: ignore
36
+ .index.tolist()
37
+ )
38
+ assert isinstance(labels, list)
39
+ return labels # type: ignore
40
+
41
+ def _description2id(self, description: str) -> int | None:
42
+ label = self._description2label.get(description)
43
+ return self._label2id.get(label) # type: ignore
44
+
45
+ @cached_property
46
+ def _label2id(self) -> dict[str, int]:
47
+ offset = len(self.special_tokens) + 1 # +1 for "<none>"
48
+ return {label: offset + i for i, label in enumerate(self._labels)}
49
+
50
+ @cached_property
51
+ def special_tokens(self) -> list[str]:
52
+ """List of special tokens which come before cluster toknes in vocab."""
53
+ return ["<pad>", "<motif>", "<unk>"]
54
+
55
+ @cached_property
56
+ def vocab(self):
57
+ annotation_tokens = [f"<ra:{id}>" for _, id in self._label2id.items()]
58
+ return self.special_tokens + ["<none>"] + annotation_tokens
59
+
60
+ @cached_property
61
+ def vocab_to_index(self) -> dict[str, int]:
62
+ return {token: token_id for token_id, token in enumerate(self.vocab)}
63
+
64
+ @cached_property
65
+ def vocabulary(self) -> list[str]:
66
+ """Full vocabulary."""
67
+ return [*self.special_tokens, "<none>", *self._labels]
68
+
69
+ def get_special_tokens_mask(self, encoded: torch.Tensor) -> torch.Tensor:
70
+ """Determines where in the sequence are special tokens."""
71
+ return encoded[:, 0] < len(self.special_tokens)
72
+
73
+ def tokenize(
74
+ self, sample: Sample | None, sequence: str, fail_on_mismatch: bool = False
75
+ ) -> list[str]:
76
+ """
77
+ # interpro_site_starts
78
+ # interpro_site_ends # should always == interpro_site_starts. but I haven't checked overall.
79
+ # interpro_site_residues # the residue identity of the specfic residue that is annotated. good for a sanity check that parsing occurred correctly.
80
+ # interpro_site_descriptions
81
+ # ASSERT (i.e. drop if bad)
82
+ # interpro_site_residues matches the residue at that position
83
+ # all these lists ^ above are the same length
84
+ """
85
+ seqlen = len(sequence)
86
+ assert seqlen >= 0
87
+ # None mean sequence is *not annotated* - so use full <pad>
88
+ if sample is None:
89
+ return ["<pad>"] * seqlen
90
+
91
+ if any(
92
+ sample.get(field) is None
93
+ for field in [
94
+ "interpro_site_descriptions",
95
+ "interpro_site_starts",
96
+ "interpro_site_ends",
97
+ "interpro_site_residues",
98
+ ]
99
+ ):
100
+ return ["<pad>"] * seqlen
101
+
102
+ num_annotations = len(sample["interpro_site_descriptions"])
103
+ if any(
104
+ len(sample[field]) != num_annotations
105
+ for field in [
106
+ "interpro_site_starts",
107
+ "interpro_site_ends",
108
+ "interpro_site_residues",
109
+ ]
110
+ ):
111
+ # mismatched length.
112
+ return ["<pad>"] * seqlen
113
+
114
+ positional_ids = [set() for _ in range(seqlen)]
115
+ for description, start, end, residues in zip(
116
+ sample["interpro_site_descriptions"],
117
+ sample["interpro_site_starts"],
118
+ sample["interpro_site_ends"],
119
+ sample["interpro_site_residues"],
120
+ ):
121
+ try:
122
+ start = int(start)
123
+ end = int(end)
124
+ except (TypeError, ValueError):
125
+ continue
126
+
127
+ # Start / End are 1-indexed [inclusive, inclusive].
128
+ if start <= 0 or end > seqlen or start > end:
129
+ print(f"invalid start/end: ({start}, {end}), len: {seqlen}")
130
+ continue
131
+
132
+ if len(residues) != (end - start) + 1:
133
+ print(f"bad reference residue: {residues}")
134
+ continue
135
+
136
+ token_id = self._description2id(description)
137
+ if token_id is None:
138
+ token_id = self.vocab_to_index["<unk>"]
139
+
140
+ for i, residue in zip(range(start - 1, end), residues):
141
+ # If there are any mismatching residues, skip the entire sample.
142
+ if sequence[i] != residue:
143
+ if fail_on_mismatch:
144
+ raise ValueError(
145
+ f"Residue mismatch at position {i} (1-indexed): {sequence[i]} != {residue}"
146
+ )
147
+ return ["<pad>"] * seqlen
148
+
149
+ positional_ids[i].add(token_id)
150
+
151
+ tokens = []
152
+ for token_ids in positional_ids:
153
+ if token_ids:
154
+ token = "<ra:" + ",".join(str(token_id) for token_id in token_ids) + ">"
155
+ else:
156
+ token = "<none>"
157
+ tokens.append(token)
158
+ return tokens
159
+
160
+ def _token2ids(self, token: str) -> list[int]:
161
+ if token.startswith("<ra:") and token.endswith(">"):
162
+ return [int(token_id) for token_id in token[4:-1].split(",")]
163
+ else:
164
+ token_id = self.vocab_to_index[token]
165
+ return [token_id]
166
+
167
+ def encode(
168
+ self, tokens: list[str], add_special_tokens: bool = True
169
+ ) -> torch.Tensor:
170
+ token_ids = torch.full(
171
+ size=(len(tokens), self.max_annotations),
172
+ dtype=torch.int64,
173
+ fill_value=self.vocab_to_index["<pad>"],
174
+ )
175
+ for i, token in enumerate(tokens):
176
+ ids = self._token2ids(token)[: self.max_annotations]
177
+ token_ids[i, : len(ids)] = torch.tensor(ids)
178
+
179
+ if add_special_tokens:
180
+ token_ids = F.pad(
181
+ token_ids, (0, 0, 1, 1), value=self.vocab_to_index["<pad>"]
182
+ )
183
+ return token_ids
184
+
185
+ def decode(self, encoded: torch.Tensor) -> list[str]:
186
+ raise NotImplementedError(
187
+ "Residue annotation decoding should be handled with util.decoding.decode_residue_annotations"
188
+ )
189
+
190
+ @property
191
+ def mask_token(self) -> str:
192
+ return "<pad>"
193
+
194
+ @property
195
+ def mask_token_id(self) -> int:
196
+ return self.vocab_to_index[self.mask_token]
197
+
198
+ @property
199
+ def bos_token(self) -> str:
200
+ return "<pad>"
201
+
202
+ @property
203
+ def bos_token_id(self) -> int:
204
+ return self.vocab_to_index[self.bos_token]
205
+
206
+ @property
207
+ def eos_token(self) -> str:
208
+ return "<pad>"
209
+
210
+ @property
211
+ def eos_token_id(self) -> int:
212
+ return self.vocab_to_index[self.eos_token]
213
+
214
+ @property
215
+ def pad_token(self) -> str:
216
+ return "<pad>"
217
+
218
+ @property
219
+ def pad_token_id(self) -> int:
220
+ return self.vocab_to_index[self.pad_token]
221
+
222
+ @property
223
+ def chain_break_token(self) -> str:
224
+ return "<pad>"
225
+
226
+ @property
227
+ def chain_break_token_id(self) -> int:
228
+ return self.vocab_to_index[self.chain_break_token]
229
+
230
+ @property
231
+ def all_token_ids(self):
232
+ return list(range(len(self.vocab)))
233
+
234
+ @property
235
+ def special_token_ids(self):
236
+ return [self.vocab_to_index[token] for token in self.special_tokens]
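For reference, a sketch of the `sample` layout this tokenizer expects, with illustrative field values; the CSV vocabulary is read lazily from the model data on first use.

```python
from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer

tok = ResidueAnnotationsTokenizer()
sample = {
    "interpro_site_descriptions": ["Active site"],  # illustrative description
    "interpro_site_starts": [3],  # 1-indexed, inclusive
    "interpro_site_ends": [3],
    "interpro_site_residues": ["K"],
}
tokens = tok.tokenize(sample, sequence="MAKQRSTVLE")
# -> ["<none>", "<none>", "<ra:...>", "<none>", ...]; any residue mismatch
# collapses the whole sample to ["<pad>"] * len(sequence).
```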
Dyna-1/esm/tokenization/sasa_tokenizer.py ADDED
@@ -0,0 +1,153 @@
1
+ from functools import cached_property
2
+
3
+ import torch
4
+
5
+ from esm.tokenization.tokenizer_base import EsmTokenizerBase
6
+ from esm.utils.constants import esm3 as C
7
+
8
+
9
+ class SASADiscretizingTokenizer(EsmTokenizerBase):
10
+ """Tokenizer for Solvent Accessible Surface Area (SASA)."""
11
+
12
+ def __init__(self, boundaries: list[float] = C.SASA_DISCRETIZATION_BOUNDARIES):
13
+ self._boundaries = sorted(boundaries)
14
+
15
+ @cached_property
16
+ def special_tokens(self) -> list[str]:
17
+ return ["<pad>", "<motif>", "<unk>"]
18
+
19
+ @cached_property
20
+ def vocab(self) -> list[str]:
21
+ """Discrete token vocabulary.
22
+
23
+ Returns:
24
+ token vocabulary with ranges represented as "<low-high>".
25
+ """
26
+ boundary_strs = ["0"] + [str(b) for b in self._boundaries] + ["inf"]
27
+ range_tokens = [
28
+ f"<{low}-{high}>"
29
+ for low, high in zip(boundary_strs[:-1], boundary_strs[1:])
30
+ ]
31
+ return self.special_tokens + range_tokens
32
+
33
+ @cached_property
34
+ def midpoints_tensor(self) -> torch.Tensor:
35
+ """Midpoints of the SASA token ranges."""
36
+ boundaries = [0] + self._boundaries + [self._boundaries[-1] * 2]
37
+ midpoint_tokens = [
38
+ (float(high) + float(low)) / 2
39
+ for low, high in zip(boundaries[:-1], boundaries[1:])
40
+ ]
41
+ midpoint_tokens = [float("nan"), float("nan"), float("nan")] + midpoint_tokens
42
+ return torch.Tensor(midpoint_tokens)
43
+
44
+ def midpoints(self) -> list[float]:
45
+ """Midpoints of the SASA token ranges."""
46
+ return self.midpoints_tensor.tolist()
47
+
48
+ @cached_property
49
+ def vocab_to_index(self) -> dict[str, int]:
50
+ """Constructs token -> token id mapping."""
51
+ return {word: i for i, word in enumerate(self.vocab)}
52
+
53
+ def get_special_tokens_mask(self, tokens: torch.Tensor) -> torch.Tensor:
54
+ """Determines which positions are special tokens.
55
+
56
+ Args:
57
+ tokens: <int>[length]
58
+ Returns:
59
+ <bool>[length] tensor, true where special tokens are located in the input.
60
+ """
61
+ return tokens < len(self.special_tokens)
62
+
63
+ def encode(
64
+ self, values: list[float | str], add_special_tokens: bool = True
65
+ ) -> torch.Tensor:
66
+ """Encodes SASA values as discrete tokens.
67
+
68
+ Args:
69
+ values: list of either SASA values or individual tokens. For example
70
+ [1.2, "<pad>", 10.3, "<pad>", 0.0].
71
+ Returns:
72
+ Token ids as tensor. Adds BOS and EOS special tokens when add_special_tokens is True.
73
+ """
74
+ ids = []
75
+ if add_special_tokens:
76
+ ids.append(self.vocab_to_index["<pad>"]) # BOS
77
+ for value in values:
78
+ if isinstance(value, (float, int)):
79
+ bucket = torch.bucketize(value, torch.tensor(self._boundaries))
80
+ token_id = len(self.special_tokens) + bucket
81
+ elif isinstance(value, str):
82
+ token_id = self.vocab_to_index[value]
83
+ else:
84
+ raise TypeError(value)
85
+ ids.append(token_id)
86
+ if add_special_tokens:
87
+ ids.append(self.vocab_to_index["<pad>"]) # EOS
88
+
89
+ return torch.tensor(ids, dtype=torch.int64)
90
+
91
+ def decode_float(self, encoded: torch.Tensor) -> list[float]:
92
+ """Decodes SASA token ids into float values."""
93
+ decoded = self.midpoints_tensor[encoded.cpu()]
94
+ nan_mask = torch.isnan(decoded)
95
+ np_arr = decoded.numpy()
96
+ np_arr[nan_mask.numpy()] = None
97
+ return np_arr.tolist()
98
+
99
+ def decode(self, encoded: torch.Tensor) -> str:
100
+ """Decodes SASA token ids."""
101
+ return ",".join(self.vocab[i] for i in encoded)
102
+
103
+ def decode_list(self, encoded: torch.Tensor) -> list[str]:
104
+ """Decodes SASA token ids."""
105
+ return [self.vocab[i] for i in encoded]
106
+
107
+ @property
108
+ def mask_token(self) -> str:
109
+ return "<pad>"
110
+
111
+ @property
112
+ def mask_token_id(self) -> int:
113
+ return self.vocab_to_index[self.mask_token]
114
+
115
+ @property
116
+ def bos_token(self) -> str:
117
+ return "<pad>"
118
+
119
+ @property
120
+ def bos_token_id(self) -> int:
121
+ return self.vocab_to_index[self.bos_token]
122
+
123
+ @property
124
+ def eos_token(self) -> str:
125
+ return "<pad>"
126
+
127
+ @property
128
+ def eos_token_id(self) -> int:
129
+ return self.vocab_to_index[self.eos_token]
130
+
131
+ @property
132
+ def pad_token(self) -> str:
133
+ return "<pad>"
134
+
135
+ @property
136
+ def pad_token_id(self) -> int:
137
+ return self.vocab_to_index[self.pad_token]
138
+
139
+ @property
140
+ def chain_break_token(self) -> str:
141
+ return "<pad>"
142
+
143
+ @property
144
+ def chain_break_token_id(self) -> int:
145
+ return self.vocab_to_index[self.chain_break_token]
146
+
147
+ @property
148
+ def all_token_ids(self):
149
+ return list(range(len(self.vocab)))
150
+
151
+ @property
152
+ def special_token_ids(self):
153
+ return [self.vocab_to_index[token] for token in self.special_tokens]
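In short: floats are bucketized against the boundaries, strings are looked up verbatim, and BOS/EOS are written as `<pad>`. A quick sketch:

```python
from esm.tokenization.sasa_tokenizer import SASADiscretizingTokenizer

tok = SASADiscretizingTokenizer()
ids = tok.encode([1.2, "<unk>", 100.0])  # length 5: <pad> + 3 values + <pad>
midpoints = tok.decode_float(ids)  # special positions map to NaN midpoints
```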
Dyna-1/esm/tokenization/sequence_tokenizer.py ADDED
@@ -0,0 +1,89 @@
1
+ from tokenizers import Tokenizer
2
+ from tokenizers.models import BPE
3
+ from tokenizers.processors import TemplateProcessing
4
+ from transformers import PreTrainedTokenizerFast
5
+
6
+ from esm.tokenization.tokenizer_base import EsmTokenizerBase
7
+ from esm.utils.constants import esm3 as C
8
+
9
+
10
+ class EsmSequenceTokenizer(PreTrainedTokenizerFast, EsmTokenizerBase):
11
+ """
12
+ Constructs an ESM tokenizer.
13
+ """
14
+
15
+ model_input_names = ["sequence_tokens", "attention_mask"]
16
+
17
+ def __init__(
18
+ self,
19
+ unk_token="<unk>",
20
+ cls_token="<cls>",
21
+ pad_token="<pad>",
22
+ mask_token="<mask>",
23
+ eos_token="<eos>",
24
+ chain_break_token="|",
25
+ **kwargs,
26
+ ):
27
+ all_tokens = C.SEQUENCE_VOCAB
28
+ token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)}
29
+
30
+ # a character-level tokenizer is the same as BPE with no token merges
31
+ bpe = BPE(token_to_id, merges=[], unk_token=unk_token)
32
+ tokenizer = Tokenizer(bpe)
33
+ special_tokens = [
34
+ cls_token,
35
+ pad_token,
36
+ mask_token,
37
+ eos_token,
38
+ chain_break_token,
39
+ ]
40
+ self.cb_token = chain_break_token
41
+ additional_special_tokens = [chain_break_token]
42
+
43
+ tokenizer.add_special_tokens(special_tokens)
44
+
45
+ # This is where we configure the automatic addition of special tokens when we call
46
+ # tokenizer(text, add_special_tokens=True). Note that you can also configure how two
47
+ # sequences are merged if you want.
48
+ tokenizer.post_processor = TemplateProcessing( # type: ignore
49
+ single="<cls> $A <eos>",
50
+ special_tokens=[
51
+ ("<cls>", tokenizer.token_to_id("<cls>")),
52
+ ("<eos>", tokenizer.token_to_id("<eos>")),
53
+ ],
54
+ )
55
+ super().__init__(
56
+ tokenizer_object=tokenizer,
57
+ unk_token=unk_token,
58
+ cls_token=cls_token,
59
+ pad_token=pad_token,
60
+ mask_token=mask_token,
61
+ eos_token=eos_token,
62
+ additional_special_tokens=additional_special_tokens,
63
+ **kwargs,
64
+ )
65
+
66
+ # These are a footgun, we never use the `bos` token anywhere so we're just overriding it here.
67
+ @property
68
+ def bos_token(self):
69
+ return self.cls_token
70
+
71
+ @property
72
+ def bos_token_id(self):
73
+ return self.cls_token_id
74
+
75
+ @property
76
+ def chain_break_token(self):
77
+ return self.cb_token
78
+
79
+ @property
80
+ def chain_break_token_id(self):
81
+ return self.convert_tokens_to_ids(self.chain_break_token)
82
+
83
+ @property
84
+ def all_token_ids(self):
85
+ return list(range(self.vocab_size))
86
+
87
+ @property
88
+ def special_token_ids(self):
89
+ return self.all_special_ids
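A quick sketch: the class behaves like any Hugging Face fast tokenizer, and the post-processor wraps inputs in `<cls> ... <eos>` when special tokens are added.

```python
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer

tok = EsmSequenceTokenizer()
enc = tok("MKTAYIAKQR")  # add_special_tokens=True by default
assert enc["input_ids"][0] == tok.cls_token_id
assert enc["input_ids"][-1] == tok.eos_token_id
```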
Dyna-1/esm/tokenization/ss_tokenizer.py ADDED
@@ -0,0 +1,125 @@
1
+ from functools import cached_property
2
+ from typing import Sequence
3
+
4
+ import torch
5
+
6
+ from esm.tokenization.tokenizer_base import EsmTokenizerBase
7
+ from esm.utils.constants import esm3 as C
8
+
9
+
10
+ class SecondaryStructureTokenizer(EsmTokenizerBase):
11
+ """Tokenizer for secondary structure strings."""
12
+
13
+ def __init__(self, kind: str = "ss8"):
14
+ assert kind in ("ss8", "ss3")
15
+ self.kind = kind
16
+
17
+ @property
18
+ def special_tokens(self) -> list[str]:
19
+ return ["<pad>", "<motif>", "<unk>"]
20
+
21
+ @cached_property
22
+ def vocab(self):
23
+ """Tokenzier vocabulary list."""
24
+ match self.kind:
25
+ case "ss8":
26
+ nonspecial_tokens = list(C.SSE_8CLASS_VOCAB) # "GHITEBSC"
27
+ case "ss3":
28
+ nonspecial_tokens = list(C.SSE_3CLASS_VOCAB) # HEC
29
+ case _:
30
+ raise ValueError(self.kind)
31
+
32
+ # The non-special tokens ids match amino acid tokens ids when possible.
33
+ return [*self.special_tokens, *nonspecial_tokens]
34
+
35
+ @cached_property
36
+ def vocab_to_index(self) -> dict[str, int]:
37
+ """Constructs token -> token id mapping."""
38
+ return {word: i for i, word in enumerate(self.vocab)}
39
+
40
+ def get_special_tokens_mask(self, tokens: torch.Tensor) -> torch.Tensor:
41
+ """Determines which positions are special tokens.
42
+
43
+ Args:
44
+ tokens: <int>[length]
45
+ Returns:
46
+ <bool>[length] tensor, true where special tokens are located in the input.
47
+ """
48
+ return tokens < len(self.special_tokens)
49
+
50
+ def encode(
51
+ self, sequence: str | Sequence[str], add_special_tokens: bool = True
52
+ ) -> torch.Tensor:
53
+ """Encode secondary structure string
54
+
55
+ Args:
56
+ string: secondary structure string e.g. "GHHIT", or as token listk.
57
+ Returns:
58
+ <int>[sequence_length] token ids representing. Will add <cls>/<eos>.
59
+ """
60
+ ids = []
61
+ if add_special_tokens:
62
+ ids.append(self.vocab_to_index["<pad>"]) # cls
63
+ for char in sequence:
64
+ ids.append(self.vocab_to_index[char])
65
+ if add_special_tokens:
66
+ ids.append(self.vocab_to_index["<pad>"]) # eos
67
+ return torch.tensor(ids, dtype=torch.int64)
68
+
69
+ def decode(self, encoded: torch.Tensor) -> str:
70
+ """Decodes token ids into secondary structure string.
71
+
72
+ Args:
73
+ encoded: <int>[length] token id array.
74
+ Returns:
75
+ Decoded secondary structure string.
76
+ """
77
+ return "".join(self.vocab[i] for i in encoded)
78
+
79
+ @property
80
+ def mask_token(self) -> str:
81
+ return "<pad>"
82
+
83
+ @property
84
+ def mask_token_id(self) -> int:
85
+ return self.vocab_to_index[self.mask_token]
86
+
87
+ @property
88
+ def bos_token(self) -> str:
89
+ return "<pad>"
90
+
91
+ @property
92
+ def bos_token_id(self) -> int:
93
+ return self.vocab_to_index[self.bos_token]
94
+
95
+ @property
96
+ def eos_token(self) -> str:
97
+ return "<pad>"
98
+
99
+ @property
100
+ def eos_token_id(self) -> int:
101
+ return self.vocab_to_index[self.eos_token]
102
+
103
+ @property
104
+ def pad_token(self) -> str:
105
+ return "<pad>"
106
+
107
+ @property
108
+ def pad_token_id(self) -> int:
109
+ return self.vocab_to_index[self.pad_token]
110
+
111
+ @property
112
+ def chain_break_token(self) -> str:
113
+ return "<pad>"
114
+
115
+ @property
116
+ def chain_break_token_id(self) -> int:
117
+ return self.vocab_to_index[self.chain_break_token]
118
+
119
+ @property
120
+ def all_token_ids(self):
121
+ return list(range(len(self.vocab)))
122
+
123
+ @property
124
+ def special_token_ids(self):
125
+ return [self.vocab_to_index[token] for token in self.special_tokens]
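A round-trip sketch for the 8-class vocabulary:

```python
from esm.tokenization.ss_tokenizer import SecondaryStructureTokenizer

tok = SecondaryStructureTokenizer(kind="ss8")
ids = tok.encode("GHHIT", add_special_tokens=False)
assert tok.decode(ids) == "GHHIT"
# With add_special_tokens=True, a <pad> id is prepended and appended, and
# decode() renders those positions literally as "<pad>".
```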
Dyna-1/esm/tokenization/structure_tokenizer.py ADDED
@@ -0,0 +1,83 @@
1
+ from esm.tokenization.tokenizer_base import EsmTokenizerBase
2
+ from esm.utils.constants import esm3 as C
3
+
4
+
5
+ class StructureTokenizer(EsmTokenizerBase):
6
+ """A convenince class for accessing special token ids of
7
+ the StructureTokenEncoder and StructureTokenDecoder."""
8
+
9
+ def __init__(self, codebook_size: int = C.VQVAE_CODEBOOK_SIZE):
10
+ self.vq_vae_special_tokens = {
11
+ "MASK": codebook_size,
12
+ "EOS": codebook_size + 1,
13
+ "BOS": codebook_size + 2,
14
+ "PAD": codebook_size + 3,
15
+ "CHAINBREAK": codebook_size + 4,
16
+ }
17
+
18
+ def mask_token(self) -> str:
19
+ raise NotImplementedError(
20
+ "Structure tokens are defined on 3D coordinates, not strings."
21
+ )
22
+
23
+ @property
24
+ def mask_token_id(self) -> int:
25
+ return self.vq_vae_special_tokens["MASK"]
26
+
27
+ def bos_token(self) -> str:
28
+ raise NotImplementedError(
29
+ "Structure tokens are defined on 3D coordinates, not strings."
30
+ )
31
+
32
+ @property
33
+ def bos_token_id(self) -> int:
34
+ return self.vq_vae_special_tokens["BOS"]
35
+
36
+ def eos_token(self) -> str:
37
+ raise NotImplementedError(
38
+ "Structure tokens are defined on 3D coordinates, not strings."
39
+ )
40
+
41
+ @property
42
+ def eos_token_id(self) -> int:
43
+ return self.vq_vae_special_tokens["EOS"]
44
+
45
+ def pad_token(self) -> str:
46
+ raise NotImplementedError(
47
+ "Structure tokens are defined on 3D coordinates, not strings."
48
+ )
49
+
50
+ @property
51
+ def pad_token_id(self) -> int:
52
+ return self.vq_vae_special_tokens["PAD"]
53
+
54
+ def chain_break_token(self) -> str:
55
+ raise NotImplementedError(
56
+ "Structure tokens are defined on 3D coordinates, not strings."
57
+ )
58
+
59
+ @property
60
+ def chain_break_token_id(self) -> int:
61
+ return self.vq_vae_special_tokens["CHAINBREAK"]
62
+
63
+ @property
64
+ def all_token_ids(self):
65
+ return list(range(C.VQVAE_CODEBOOK_SIZE + len(self.vq_vae_special_tokens)))
66
+
67
+ @property
68
+ def special_token_ids(self):
69
+ return self.vq_vae_special_tokens.values()
70
+
71
+ def encode(self, *args, **kwargs):
72
+ raise NotImplementedError(
73
+ "The StructureTokenizer class is provided as a convenience for "
74
+ "accessing special token ids of the StructureTokenEncoder and StructureTokenDecoder.\n"
75
+ "Please use them instead."
76
+ )
77
+
78
+ def decode(self, *args, **kwargs):
79
+ raise NotImplementedError(
80
+ "The StructureTokenizer class is provided as a convenience for "
81
+ "accessing special token ids of the StructureTokenEncoder and StructureTokenDecoder.\n"
82
+ "Please use them instead."
83
+ )
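A small sketch of the intended use, reading special ids only; with the default codebook size of 4096, the five specials occupy ids 4096-4100.

```python
from esm.tokenization.structure_tokenizer import StructureTokenizer

tok = StructureTokenizer()
assert tok.mask_token_id == 4096
assert tok.pad_token_id == 4099
assert len(tok.all_token_ids) == 4096 + 5
```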
Dyna-1/esm/tokenization/tokenizer_base.py ADDED
@@ -0,0 +1,44 @@
1
+ from typing import Protocol, runtime_checkable
2
+
3
+
4
+ @runtime_checkable
5
+ class EsmTokenizerBase(Protocol):
6
+ def encode(self, *args, **kwargs): ...
7
+
8
+ def decode(self, *args, **kwargs): ...
9
+
10
+ @property
11
+ def mask_token(self) -> str: ...
12
+
13
+ @property
14
+ def mask_token_id(self) -> int: ...
15
+
16
+ @property
17
+ def bos_token(self) -> str: ...
18
+
19
+ @property
20
+ def bos_token_id(self) -> int: ...
21
+
22
+ @property
23
+ def eos_token(self) -> str: ...
24
+
25
+ @property
26
+ def eos_token_id(self) -> int: ...
27
+
28
+ @property
29
+ def pad_token(self) -> str: ...
30
+
31
+ @property
32
+ def pad_token_id(self) -> int: ...
33
+
34
+ @property
35
+ def chain_break_token(self) -> str: ...
36
+
37
+ @property
38
+ def chain_break_token_id(self) -> int: ...
39
+
40
+ @property
41
+ def all_token_ids(self): ...
42
+
43
+ @property
44
+ def special_token_ids(self): ...
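Because the protocol is `runtime_checkable`, conformance is structural: any object exposing these members passes an `isinstance` check, whether or not it names `EsmTokenizerBase` as a base class. For example:

```python
from esm.tokenization.ss_tokenizer import SecondaryStructureTokenizer
from esm.tokenization.tokenizer_base import EsmTokenizerBase

assert isinstance(SecondaryStructureTokenizer(), EsmTokenizerBase)
```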
Dyna-1/esm/utils/constants/api.py ADDED
@@ -0,0 +1,5 @@
1
+ MAX_TOPK_SEQUENCE = 32
2
+ MAX_TOPK_STRUCTURE = MAX_TOPK_SEQUENCE
3
+ MAX_TOPK_SECONDARY_STRUCTURE = MAX_TOPK_SEQUENCE
4
+ MAX_TOPK_SASA = MAX_TOPK_SEQUENCE
5
+ MAX_TOPK_FUNCTION = MAX_TOPK_SEQUENCE
Dyna-1/esm/utils/constants/esm3.py ADDED
@@ -0,0 +1,130 @@
1
+ import os
2
+ from functools import cache
3
+ from pathlib import Path
4
+
5
+ from huggingface_hub import snapshot_download
6
+
7
+ SEQUENCE_BOS_TOKEN = 0
8
+ SEQUENCE_PAD_TOKEN = 1
9
+ SEQUENCE_EOS_TOKEN = 2
10
+ SEQUENCE_CHAINBREAK_TOKEN = 31
11
+ SEQUENCE_MASK_TOKEN = 32
12
+
13
+ VQVAE_CODEBOOK_SIZE = 4096
14
+ VQVAE_SPECIAL_TOKENS = {
15
+ "MASK": VQVAE_CODEBOOK_SIZE,
16
+ "EOS": VQVAE_CODEBOOK_SIZE + 1,
17
+ "BOS": VQVAE_CODEBOOK_SIZE + 2,
18
+ "PAD": VQVAE_CODEBOOK_SIZE + 3,
19
+ "CHAINBREAK": VQVAE_CODEBOOK_SIZE + 4,
20
+ }
21
+ VQVAE_DIRECTION_LOSS_BINS = 16
22
+ VQVAE_PAE_BINS = 64
23
+ VQVAE_MAX_PAE_BIN = 31.0
24
+ VQVAE_PLDDT_BINS = 50
25
+
26
+ STRUCTURE_MASK_TOKEN = VQVAE_SPECIAL_TOKENS["MASK"]
27
+ STRUCTURE_BOS_TOKEN = VQVAE_SPECIAL_TOKENS["BOS"]
28
+ STRUCTURE_EOS_TOKEN = VQVAE_SPECIAL_TOKENS["EOS"]
29
+ STRUCTURE_PAD_TOKEN = VQVAE_SPECIAL_TOKENS["PAD"]
30
+ STRUCTURE_CHAINBREAK_TOKEN = VQVAE_SPECIAL_TOKENS["CHAINBREAK"]
31
+ STRUCTURE_UNDEFINED_TOKEN = 955
32
+
33
+ SASA_PAD_TOKEN = 0
34
+
35
+ SS8_PAD_TOKEN = 0
36
+
37
+ INTERPRO_PAD_TOKEN = 0
38
+
39
+ RESIDUE_PAD_TOKEN = 0
40
+
41
+ CHAIN_BREAK_STR = "|"
42
+
43
+ SEQUENCE_BOS_STR = "<cls>"
44
+ SEQUENCE_EOS_STR = "<eos>"
45
+
46
+ MASK_STR_SHORT = "_"
47
+ SEQUENCE_MASK_STR = "<mask>"
48
+ SASA_MASK_STR = "<unk>"
49
+ SS8_MASK_STR = "<unk>"
50
+
51
+ # fmt: off
52
+ SEQUENCE_VOCAB = [
53
+ "<cls>", "<pad>", "<eos>", "<unk>",
54
+ "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K",
55
+ "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z",
56
+ "O", ".", "-", "|",
57
+ "<mask>",
58
+ ]
59
+ # fmt: on
60
+
61
+ SSE_8CLASS_VOCAB = "GHITEBSC"
62
+ SSE_3CLASS_VOCAB = "HEC"
63
+ SSE_8CLASS_TO_3CLASS_MAP = {
64
+ "G": "H",
65
+ "H": "H",
66
+ "I": "H",
67
+ "T": "C",
68
+ "E": "E",
69
+ "B": "E",
70
+ "S": "C",
71
+ "C": "C",
72
+ }
73
+
74
+ SASA_DISCRETIZATION_BOUNDARIES = [
75
+ 0.8,
76
+ 4.0,
77
+ 9.6,
78
+ 16.4,
79
+ 24.5,
80
+ 32.9,
81
+ 42.0,
82
+ 51.5,
83
+ 61.2,
84
+ 70.9,
85
+ 81.6,
86
+ 93.3,
87
+ 107.2,
88
+ 125.4,
89
+ 151.4,
90
+ ]
91
+
92
+ MAX_RESIDUE_ANNOTATIONS = 16
93
+
94
+
95
+ TFIDF_VECTOR_SIZE = 58641
96
+
97
+
99
+ @cache
100
+ def data_root(model: str):
101
+ if "INFRA_PROVIDER" in os.environ:
102
+ return Path("")
103
+ # Try to download from Hugging Face if it doesn't exist locally
104
+ if model.startswith("esm3"):
105
+ path = Path(snapshot_download(repo_id="EvolutionaryScale/esm3-sm-open-v1"))
106
+ elif model.startswith("esmc-300"):
107
+ path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-300m-2024-12"))
108
+ elif model.startswith("esmc-600"):
109
+ path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-600m-2024-12"))
110
+ else:
111
+ raise ValueError(f"{model=} is an invalid model name.")
112
+ return path
113
+
114
+
115
+ IN_REPO_DATA_FOLDER = Path(__file__).parents[2] / "data"
116
+
117
+ INTERPRO_ENTRY = IN_REPO_DATA_FOLDER / "entry_list_safety_29026.list"
118
+ INTERPRO_HIERARCHY = IN_REPO_DATA_FOLDER / "ParentChildTreeFile.txt"
119
+ INTERPRO2GO = IN_REPO_DATA_FOLDER / "ParentChildTreeFile.txt"
120
+ INTERPRO_2ID = "data/tag_dict_4_safety_filtered.json"
121
+
122
+ LSH_TABLE_PATHS = {"8bit": "data/hyperplanes_8bit_58641.npz"}
123
+
124
+ KEYWORDS_VOCABULARY = (
125
+ IN_REPO_DATA_FOLDER / "keyword_vocabulary_safety_filtered_58641.txt"
126
+ )
127
+ KEYWORDS_IDF = IN_REPO_DATA_FOLDER / "keyword_idf_safety_filtered_58641.npy"
128
+
129
+ RESID_CSV = "data/uniref90_and_mgnify90_residue_annotations_gt_1k_proteins.csv"
130
+ INTERPRO2KEYWORDS = IN_REPO_DATA_FOLDER / "interpro_29026_to_keywords_58641.csv"
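As a worked example of the vocabularies above, an 8-class secondary structure string collapses to 3 classes via a per-character lookup:

```python
from esm.utils.constants import esm3 as C

ss8 = "GHITEBSC"
ss3 = "".join(C.SSE_8CLASS_TO_3CLASS_MAP[c] for c in ss8)
assert ss3 == "HHHCEECC"
```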
Dyna-1/esm/utils/constants/models.py ADDED
@@ -0,0 +1,25 @@
1
+ # Model names
2
+ ESM3_OPEN_SMALL = "esm3_sm_open_v1"
3
+ ESM3_OPEN_SMALL_ALIAS_1 = "esm3-open-2024-03"
4
+ ESM3_OPEN_SMALL_ALIAS_2 = "esm3-sm-open-v1"
5
+ ESM3_OPEN_SMALL_ALIAS_3 = "esm3-open"
6
+ ESM3_STRUCTURE_ENCODER_V0 = "esm3_structure_encoder_v0"
7
+ ESM3_STRUCTURE_DECODER_V0 = "esm3_structure_decoder_v0"
8
+ ESM3_FUNCTION_DECODER_V0 = "esm3_function_decoder_v0"
9
+ ESMC_600M = "esmc_600m"
10
+ ESMC_300M = "esmc_300m"
11
+
12
+
13
+ def model_is_locally_supported(x: str):
14
+ return x in {
15
+ ESM3_OPEN_SMALL,
16
+ ESM3_OPEN_SMALL_ALIAS_1,
17
+ ESM3_OPEN_SMALL_ALIAS_2,
18
+ ESM3_OPEN_SMALL_ALIAS_3,
19
+ }
20
+
21
+
22
+ def normalize_model_name(x: str):
23
+ if x in {ESM3_OPEN_SMALL_ALIAS_1, ESM3_OPEN_SMALL_ALIAS_2, ESM3_OPEN_SMALL_ALIAS_3}:
24
+ return ESM3_OPEN_SMALL
25
+ return x
Dyna-1/esm/utils/constants/physics.py ADDED
@@ -0,0 +1,5 @@
1
+ BB_COORDINATES = [
2
+ [0.5256, 1.3612, 0.0000],
3
+ [0.0000, 0.0000, 0.0000],
4
+ [-1.5251, 0.0000, 0.0000],
5
+ ]
Dyna-1/esm/utils/decoding.py ADDED
@@ -0,0 +1,244 @@
1
+ import warnings
2
+ from typing import cast
3
+
4
+ import attr
5
+ import torch
6
+
7
+ from esm.models.function_decoder import FunctionTokenDecoder
8
+ from esm.models.vqvae import StructureTokenDecoder
9
+ from esm.sdk.api import ESMProtein, ESMProteinTensor
10
+ from esm.tokenization import TokenizerCollectionProtocol
11
+ from esm.tokenization.function_tokenizer import (
12
+ InterProQuantizedTokenizer,
13
+ )
14
+ from esm.tokenization.residue_tokenizer import (
15
+ ResidueAnnotationsTokenizer,
16
+ )
17
+ from esm.tokenization.sasa_tokenizer import (
18
+ SASADiscretizingTokenizer,
19
+ )
20
+ from esm.tokenization.sequence_tokenizer import (
21
+ EsmSequenceTokenizer,
22
+ )
23
+ from esm.tokenization.ss_tokenizer import (
24
+ SecondaryStructureTokenizer,
25
+ )
26
+ from esm.tokenization.structure_tokenizer import (
27
+ StructureTokenizer,
28
+ )
29
+ from esm.tokenization.tokenizer_base import EsmTokenizerBase
30
+ from esm.utils.constants import esm3 as C
31
+ from esm.utils.function.encode_decode import (
32
+ decode_function_tokens,
33
+ decode_residue_annotation_tokens,
34
+ )
35
+ from esm.utils.misc import maybe_list
36
+ from esm.utils.structure.protein_chain import ProteinChain
37
+ from esm.utils.types import FunctionAnnotation
38
+
39
+
40
+ def decode_protein_tensor(
41
+ input: ESMProteinTensor,
42
+ tokenizers: TokenizerCollectionProtocol,
43
+ structure_token_decoder: StructureTokenDecoder,
44
+ function_token_decoder: FunctionTokenDecoder | None = None,
45
+ ) -> ESMProtein:
46
+ input = attr.evolve(input) # Make a copy
47
+
48
+ sequence = None
49
+ secondary_structure = None
50
+ sasa = None
51
+ function_annotations = []
52
+
53
+ coordinates = None
54
+
55
+ # If all pad tokens, set to None
56
+ for track in attr.fields(ESMProteinTensor):
57
+ tokens: torch.Tensor | None = getattr(input, track.name)
58
+ if track.name == "coordinates" or track.name == "potential_sequence_of_concern":
59
+ continue
60
+ if tokens is not None:
61
+ tokens = tokens[1:-1] # Remove BOS and EOS tokens
62
+ tokens = tokens.flatten() # For multi-track tensors
63
+ track_tokenizer = getattr(tokenizers, track.name)
64
+ if torch.all(tokens == track_tokenizer.pad_token_id):
65
+ setattr(input, track.name, None)
66
+ # If structure track has any mask tokens, do not decode.
67
+ if track.name == "structure" and torch.any(
68
+ tokens == track_tokenizer.mask_token_id
69
+ ):
70
+ setattr(input, track.name, None)
71
+
72
+ if input.sequence is not None:
73
+ sequence = decode_sequence(input.sequence, tokenizers.sequence)
74
+
75
+ plddt, ptm = None, None
76
+ if input.structure is not None:
77
+ # Note: We give priority to the structure tokens over the coordinates when decoding
78
+ coordinates, plddt, ptm = decode_structure(
79
+ structure_tokens=input.structure,
80
+ structure_decoder=structure_token_decoder,
81
+ structure_tokenizer=tokenizers.structure,
82
+ sequence=sequence,
83
+ )
84
+ elif input.coordinates is not None:
85
+ coordinates = input.coordinates[1:-1, ...]
86
+
87
+ if input.secondary_structure is not None:
88
+ secondary_structure = decode_secondary_structure(
89
+ input.secondary_structure, tokenizers.secondary_structure
90
+ )
91
+ if input.sasa is not None:
92
+ sasa = decode_sasa(input.sasa, tokenizers.sasa)
93
+ if input.function is not None:
94
+ if function_token_decoder is None:
95
+ raise ValueError(
96
+ "Cannot decode function annotations without a function token decoder"
97
+ )
98
+ function_track_annotations = decode_function_annotations(
99
+ input.function,
100
+ function_token_decoder=function_token_decoder,
101
+ function_tokenizer=tokenizers.function,
102
+ )
103
+ function_annotations.extend(function_track_annotations)
104
+ if input.residue_annotations is not None:
105
+ residue_annotations = decode_residue_annotations(
106
+ input.residue_annotations, tokenizers.residue_annotations
107
+ )
108
+ function_annotations.extend(residue_annotations)
109
+
110
+ return ESMProtein(
111
+ sequence=sequence,
112
+ secondary_structure=secondary_structure,
113
+ sasa=sasa, # type: ignore
114
+ function_annotations=function_annotations if function_annotations else None,
115
+ coordinates=coordinates,
116
+ plddt=plddt,
117
+ ptm=ptm,
118
+ potential_sequence_of_concern=input.potential_sequence_of_concern,
119
+ )
120
+
121
+
122
+ def _bos_eos_warn(msg: str, tensor: torch.Tensor, tok: EsmTokenizerBase):
123
+ if tensor[0] != tok.bos_token_id:
124
+ warnings.warn(
125
+ f"{msg} does not start with BOS token; the token is ignored. Expected BOS={tok.bos_token_id}, got: {tensor}"
126
+ )
127
+ if tensor[-1] != tok.eos_token_id:
128
+ warnings.warn(
129
+ f"{msg} does not end with EOS token; the token is ignored. Expected EOS={tok.eos_token_id}, got: {tensor}"
130
+ )
131
+
132
+
133
+ def decode_sequence(
134
+ sequence_tokens: torch.Tensor, sequence_tokenizer: EsmSequenceTokenizer, **kwargs
135
+ ) -> str:
136
+ _bos_eos_warn("Sequence", sequence_tokens, sequence_tokenizer)
137
+ sequence = sequence_tokenizer.decode(sequence_tokens, **kwargs)
138
+ sequence = sequence.replace(" ", "")
139
+ sequence = sequence.replace(sequence_tokenizer.mask_token, C.MASK_STR_SHORT)
140
+ sequence = sequence.replace(sequence_tokenizer.cls_token, "")
141
+ sequence = sequence.replace(sequence_tokenizer.eos_token, "")
142
+
143
+ return sequence
144
+
145
+
146
+ def decode_structure(
147
+ structure_tokens: torch.Tensor,
148
+ structure_decoder: StructureTokenDecoder,
149
+ structure_tokenizer: StructureTokenizer,
150
+ sequence: str | None = None,
151
+ ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
152
+ is_singleton = len(structure_tokens.size()) == 1
153
+ if is_singleton:
154
+ structure_tokens = structure_tokens.unsqueeze(0)
155
+ else:
156
+ raise ValueError(
157
+ f"Only one structure can be decoded at a time, got structure tokens of shape {structure_tokens.size()}"
158
+ )
159
+ _bos_eos_warn("Structure", structure_tokens[0], structure_tokenizer)
160
+
161
+ decoder_output = structure_decoder.decode(structure_tokens)
162
+ bb_coords: torch.Tensor = decoder_output["bb_pred"][
163
+ 0, 1:-1, ...
164
+ ] # Remove BOS and EOS tokens
165
+ bb_coords = bb_coords.detach().cpu()
166
+
167
+ if "plddt" in decoder_output:
168
+ plddt = decoder_output["plddt"][0, 1:-1]
169
+ plddt = plddt.detach().cpu()
170
+ else:
171
+ plddt = None
172
+
173
+ if "ptm" in decoder_output:
174
+ ptm = decoder_output["ptm"]
175
+ else:
176
+ ptm = None
177
+
178
+ chain = ProteinChain.from_backbone_atom_coordinates(bb_coords, sequence=sequence)
179
+ chain = chain.infer_oxygen()
180
+ return torch.tensor(chain.atom37_positions), plddt, ptm
181
+
182
+
183
+ def decode_secondary_structure(
184
+ secondary_structure_tokens: torch.Tensor, ss_tokenizer: SecondaryStructureTokenizer
185
+ ) -> str:
186
+ _bos_eos_warn("Secondary structure", secondary_structure_tokens, ss_tokenizer)
187
+ secondary_structure_tokens = secondary_structure_tokens[1:-1]
188
+ secondary_structure = ss_tokenizer.decode(secondary_structure_tokens)
189
+ return secondary_structure
190
+
191
+
192
+ def decode_sasa(
193
+ sasa_tokens: torch.Tensor, sasa_tokenizer: SASADiscretizingTokenizer
194
+ ) -> list[float]:
195
+ if sasa_tokens[0] != 0:
196
+ raise ValueError("SASA does not start with 0 corresponding to BOS token")
197
+ if sasa_tokens[-1] != 0:
198
+ raise ValueError("SASA does not end with 0 corresponding to EOS token")
199
+ sasa_tokens = sasa_tokens[1:-1]
200
+ if sasa_tokens.dtype in [
201
+ torch.int8,
202
+ torch.int16,
203
+ torch.int32,
204
+ torch.int64,
205
+ torch.long,
206
+ ]:
207
+ # Decode if int
208
+ # handles turning NaN's into None's
209
+ sasa = sasa_tokenizer.decode_float(sasa_tokens)
210
+ else:
211
+ # If already float, just convert to list
212
+ sasa = cast(list[float], maybe_list(sasa_tokens, convert_nan_to_none=True))
213
+
214
+ return sasa
215
+
216
+
217
+ def decode_function_annotations(
218
+ function_annotation_tokens: torch.Tensor,
219
+ function_token_decoder: FunctionTokenDecoder,
220
+ function_tokenizer: InterProQuantizedTokenizer,
221
+ **kwargs,
222
+ ) -> list[FunctionAnnotation]:
223
+ # No need to check for BOS/EOS as function annotations are not affected
224
+
225
+ function_annotations = decode_function_tokens(
226
+ function_annotation_tokens,
227
+ function_token_decoder=function_token_decoder,
228
+ function_tokens_tokenizer=function_tokenizer,
229
+ **kwargs,
230
+ )
231
+ return function_annotations
232
+
233
+
234
+ def decode_residue_annotations(
235
+ residue_annotation_tokens: torch.Tensor,
236
+ residue_annotation_decoder: ResidueAnnotationsTokenizer,
237
+ ) -> list[FunctionAnnotation]:
238
+ # No need to check for BOS/EOS as function annotations are not affected
239
+
240
+ residue_annotations = decode_residue_annotation_tokens(
241
+ residue_annotations_token_ids=residue_annotation_tokens,
242
+ residue_annotations_tokenizer=residue_annotation_decoder,
243
+ )
244
+ return residue_annotations
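
A hedged end-to-end sketch of decode_protein_tensor; protein_tensor, tokenizers, and both decoders are assumed to be constructed elsewhere (e.g. by the pretrained-model loaders):

from esm.utils.decoding import decode_protein_tensor

protein = decode_protein_tensor(
    input=protein_tensor,                      # an ESMProteinTensor with token tracks
    tokenizers=tokenizers,                     # a TokenizerCollectionProtocol
    structure_token_decoder=structure_decoder,
    function_token_decoder=function_decoder,   # only needed when the function track is set
)
# Tracks that were all padding come back as None; structure decoding also yields pLDDT/pTM.
print(protein.sequence, protein.plddt, protein.ptm)
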
Dyna-1/esm/utils/encoding.py ADDED
@@ -0,0 +1,246 @@
1
+ from typing import Sequence
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from esm.models.vqvae import StructureTokenEncoder
7
+ from esm.tokenization.function_tokenizer import (
8
+ InterProQuantizedTokenizer as EsmFunctionTokenizer,
9
+ )
10
+
11
+ from esm.tokenization.residue_tokenizer import (
12
+ ResidueAnnotationsTokenizer,
13
+ )
14
+ from esm.tokenization.sasa_tokenizer import (
15
+ SASADiscretizingTokenizer,
16
+ )
17
+ from esm.tokenization.sequence_tokenizer import (
18
+ EsmSequenceTokenizer,
19
+ )
20
+ from esm.tokenization.ss_tokenizer import (
21
+ SecondaryStructureTokenizer,
22
+ )
23
+ from esm.tokenization.structure_tokenizer import (
24
+ StructureTokenizer,
25
+ )
26
+ from esm.utils.constants import esm3 as C
27
+ from esm.utils.function.encode_decode import (
28
+ encode_function_annotations,
29
+ )
30
+ from esm.utils.structure.protein_chain import ProteinChain
31
+ from esm.utils.types import FunctionAnnotation
32
+
33
+
34
+ # Raw Defaults
35
+ def get_default_sequence(sequence_length: int) -> str:
36
+ return C.MASK_STR_SHORT * sequence_length
37
+
38
+
39
+ def get_default_secondary_structure(sequence_length: int) -> str:
40
+ return C.MASK_STR_SHORT * sequence_length
41
+
42
+
43
+ def get_default_sasa(sequence_length: int) -> Sequence[float | str | None]:
44
+ return [None] * sequence_length
45
+
46
+
47
+ # Tokenization
48
+ def tokenize_sequence(
49
+ sequence: str,
50
+ sequence_tokenizer: EsmSequenceTokenizer,
51
+ add_special_tokens: bool = True,
52
+ ) -> torch.Tensor:
53
+ sequence = sequence.replace(C.MASK_STR_SHORT, sequence_tokenizer.mask_token)
54
+ sequence_tokens = sequence_tokenizer.encode(
55
+ sequence, add_special_tokens=add_special_tokens
56
+ )
57
+ sequence_tokens = torch.tensor(sequence_tokens, dtype=torch.int64)
58
+ return sequence_tokens
59
+
60
+
61
+ def tokenize_structure(
62
+ coordinates: torch.Tensor,
63
+ structure_encoder: StructureTokenEncoder,
64
+ structure_tokenizer: StructureTokenizer,
65
+ reference_sequence: str = "",
66
+ add_special_tokens: bool = True,
67
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
68
+ device = next(structure_encoder.parameters()).device
69
+ chain = ProteinChain.from_atom37(
70
+ coordinates, sequence=reference_sequence if reference_sequence else None
71
+ )
72
+
73
+ # Setup padding
74
+ if reference_sequence and len(reference_sequence) != coordinates.size(0):
75
+ raise ValueError(
76
+ f"Reference sequence length ({len(reference_sequence)}) does not match the number of residues in the coordinates ({coordinates.size(0)})"
77
+ )
78
+
79
+ left_pad = 0
80
+ right_pad = 0
81
+
82
+ if add_special_tokens:
83
+ left_pad += 1 # Add space for BOS token
84
+ right_pad += 1 # Add space for EOS token
85
+
86
+ coordinates, plddt, residue_index = chain.to_structure_encoder_inputs()
87
+ coordinates = coordinates.to(device) # (1, L, 37, 3)
88
+ plddt = plddt.to(device) # (1, L)
89
+ residue_index = residue_index.to(device) # (1, L)
90
+ _, structure_tokens = structure_encoder.encode(
91
+ coordinates, residue_index=residue_index
92
+ )
93
+ coordinates = torch.squeeze(coordinates, dim=0) # (L, 37, 3) # type: ignore
94
+ plddt = torch.squeeze(plddt, dim=0) # (L,) # type: ignore
95
+ structure_tokens = torch.squeeze(structure_tokens, dim=0) # (L,) # type: ignore
96
+
97
+ # Add space for BOS and EOS tokens
98
+ if add_special_tokens:
99
+ coordinates = F.pad(
100
+ coordinates, (0, 0, 0, 0, left_pad, right_pad), value=torch.inf
101
+ )
102
+ plddt = F.pad(plddt, (left_pad, right_pad), value=0)
103
+ structure_tokens = F.pad(
104
+ structure_tokens,
105
+ (left_pad, right_pad),
106
+ value=structure_tokenizer.mask_token_id,
107
+ )
108
+ structure_tokens[0] = structure_tokenizer.bos_token_id
109
+ structure_tokens[-1] = structure_tokenizer.eos_token_id
110
+ return coordinates, plddt, structure_tokens
111
+
112
+
113
+ def tokenize_secondary_structure(
114
+ secondary_structure: str | Sequence[str],
115
+ secondary_structure_tokenizer: SecondaryStructureTokenizer,
116
+ add_special_tokens: bool = True,
117
+ ) -> torch.Tensor:
118
+ if isinstance(secondary_structure, str):
119
+ # Ensure only one char per token
120
+ secondary_structure = secondary_structure.replace(
121
+ secondary_structure_tokenizer.mask_token, C.MASK_STR_SHORT
122
+ )
123
+
124
+ # Input as list of chars
125
+ secondary_structure = [char for char in secondary_structure]
126
+
127
+ # Use tokenizer's mask token
128
+ secondary_structure = [
129
+ secondary_structure_tokenizer.mask_token if char == C.MASK_STR_SHORT else char
130
+ for char in secondary_structure
131
+ ]
132
+
133
+ secondary_structure_tokens = secondary_structure_tokenizer.encode(
134
+ secondary_structure, add_special_tokens=add_special_tokens
135
+ )
136
+ return secondary_structure_tokens
137
+
138
+
139
+ def tokenize_sasa(
140
+ sasa: Sequence[float | str | None],
141
+ sasa_tokenizer: SASADiscretizingTokenizer,
142
+ add_special_tokens: bool = True,
143
+ ):
144
+ sasa_tokens = sasa_tokenizer.encode(
145
+ [sasa_tokenizer.mask_token if value is None else value for value in sasa],
146
+ add_special_tokens=add_special_tokens,
147
+ )
148
+ return sasa_tokens
149
+
150
+
151
+ def tokenize_function_annotations(
152
+ function_annotations: Sequence[FunctionAnnotation],
153
+ reference_sequence: str,
154
+ function_tokenizer: EsmFunctionTokenizer,
155
+ residue_annotation_tokenizer: ResidueAnnotationsTokenizer,
156
+ add_special_tokens: bool = True,
157
+ ) -> tuple[torch.Tensor, torch.Tensor]:
158
+ function_tokens, residue_annotation_tokens = encode_function_annotations(
159
+ sequence=reference_sequence,
160
+ function_annotations=function_annotations,
161
+ function_tokens_tokenizer=function_tokenizer,
162
+ residue_annotations_tokenizer=residue_annotation_tokenizer,
163
+ add_special_tokens=add_special_tokens,
164
+ )
165
+ return function_tokens, residue_annotation_tokens
166
+
167
+
168
+
169
+
170
+ # Tokenized Defaults
171
+ def get_default_sequence_tokens(
172
+ sequence_length: int, sequence_tokenizer: EsmSequenceTokenizer
173
+ ) -> torch.Tensor:
174
+ assert sequence_tokenizer.mask_token_id is not None
175
+ assert sequence_tokenizer.bos_token_id is not None
176
+ assert sequence_tokenizer.eos_token_id is not None
177
+
178
+ sequence_tokens = torch.full(
179
+ (sequence_length + 2,), sequence_tokenizer.mask_token_id
180
+ )
181
+ sequence_tokens[0] = sequence_tokenizer.bos_token_id
182
+ sequence_tokens[-1] = sequence_tokenizer.eos_token_id
183
+
184
+ return sequence_tokens
185
+
186
+
187
+ def get_default_structure_tokens(
188
+ sequence_length: int, structure_tokenizer: StructureTokenizer
189
+ ) -> torch.Tensor:
190
+ structure_tokens = (
191
+ torch.ones((sequence_length + 2,), dtype=torch.int64)
192
+ * structure_tokenizer.mask_token_id
193
+ )
194
+ # Always include BOS and EOS tokens
195
+ structure_tokens[0] = structure_tokenizer.bos_token_id
196
+ structure_tokens[-1] = structure_tokenizer.eos_token_id
197
+ return structure_tokens
198
+
199
+
200
+ def get_default_secondary_structure_tokens(
201
+ sequence_length: int, secondary_structure_tokenizer: SecondaryStructureTokenizer
202
+ ) -> torch.Tensor:
203
+ ss8_tokens = torch.full(
204
+ (sequence_length + 2,), secondary_structure_tokenizer.mask_token_id
205
+ )
206
+ ss8_tokens[0] = secondary_structure_tokenizer.bos_token_id
207
+ ss8_tokens[-1] = secondary_structure_tokenizer.eos_token_id
208
+
209
+ return ss8_tokens
210
+
211
+
212
+ def get_default_sasa_tokens(
213
+ sequence_length: int, sasa_tokenizer: SASADiscretizingTokenizer
214
+ ) -> torch.Tensor:
215
+ sasa_tokens = torch.full((sequence_length + 2,), sasa_tokenizer.mask_token_id)
216
+ sasa_tokens[0] = sasa_tokenizer.bos_token_id
217
+ sasa_tokens[-1] = sasa_tokenizer.eos_token_id
218
+ return sasa_tokens
219
+
220
+
221
+ def get_default_function_tokens(
222
+ sequence_length: int, function_tokenizer: EsmFunctionTokenizer
223
+ ) -> torch.Tensor:
224
+ function_tokens = (
225
+ torch.ones((sequence_length + 2, function_tokenizer.depth), dtype=torch.int64)
226
+ * function_tokenizer.pad_token_id
227
+ )
228
+ # Always include BOS and EOS tokens
229
+ function_tokens[0] = function_tokenizer.bos_token_id
230
+ function_tokens[-1] = function_tokenizer.eos_token_id
231
+ return function_tokens
232
+
233
+
234
+ def get_default_residue_annotation_tokens(
235
+ sequence_length: int, residue_annotation_tokenizer: ResidueAnnotationsTokenizer
236
+ ) -> torch.Tensor:
237
+ residue_annotation_tokens = (
238
+ torch.ones((sequence_length + 2, C.MAX_RESIDUE_ANNOTATIONS), dtype=torch.int64)
239
+ * residue_annotation_tokenizer.pad_token_id
240
+ )
241
+ # Always include BOS and EOS tokens
242
+ residue_annotation_tokens[0] = residue_annotation_tokenizer.bos_token_id
243
+ residue_annotation_tokens[-1] = residue_annotation_tokenizer.eos_token_id
244
+ return residue_annotation_tokens
245
+
246
+
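
All of the get_default_*_tokens helpers above share one invariant: a (sequence_length + 2,) tensor filled with mask (or pad) ids, with BOS at index 0 and EOS at index -1. A minimal sketch, assuming a sequence tokenizer built elsewhere:

from esm.utils.encoding import get_default_sequence_tokens

toks = get_default_sequence_tokens(5, sequence_tokenizer)  # hypothetical tokenizer instance
assert toks.shape == (7,)  # 5 residues + BOS + EOS
assert toks[0] == sequence_tokenizer.bos_token_id
assert toks[-1] == sequence_tokenizer.eos_token_id
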
Dyna-1/esm/utils/function/encode_decode.py ADDED
@@ -0,0 +1,187 @@
1
+ import re
2
+ from typing import Sequence
3
+
4
+ import torch
5
+
6
+ from esm.models.function_decoder import (
7
+ FunctionTokenDecoder,
8
+ merge_annotations,
9
+ )
10
+ from esm.tokenization.function_tokenizer import (
11
+ InterProQuantizedTokenizer,
12
+ )
13
+ from esm.tokenization.residue_tokenizer import (
14
+ ResidueAnnotationsTokenizer,
15
+ )
16
+ from esm.utils.constants import esm3 as C
17
+ from esm.utils.types import FunctionAnnotation
18
+
19
+
20
+ def encode_function_annotations(
21
+ sequence: str,
22
+ function_annotations: Sequence[FunctionAnnotation],
23
+ function_tokens_tokenizer: InterProQuantizedTokenizer,
24
+ residue_annotations_tokenizer: ResidueAnnotationsTokenizer,
25
+ add_special_tokens: bool = True,
26
+ ) -> tuple[torch.Tensor, torch.Tensor]:
27
+ assert isinstance(
28
+ residue_annotations_tokenizer, ResidueAnnotationsTokenizer
29
+ ), "residue_annotations_tokenizer must be of type ResidueAnnotationsTokenizer"
30
+
31
+ # Split the user's annotations by type
32
+ ft_annotations: list[FunctionAnnotation] = []
33
+ ra_annotations: list[FunctionAnnotation] = []
34
+ for fa in function_annotations:
35
+ assert (
36
+ 1 <= fa.start <= fa.end <= len(sequence)
37
+ ), f"Invalid (start, end) in function annotation {fa}. Indices 1-indexed and [inclusive, inclusive]"
38
+
39
+ supported_label = False
40
+
41
+ # Is it an InterPro label?
42
+ if match := re.search(r"IPR\d+", fa.label):
43
+ if match.group() in function_tokens_tokenizer.interpro_to_index:
44
+ ft_annotations.append(fa)
45
+ supported_label = True
46
+
47
+ # Is it a function keyword?
48
+ if fa.label in function_tokens_tokenizer._tfidf.vocab_to_index:
49
+ ft_annotations.append(fa)
50
+ supported_label = True
51
+
52
+ # Is it a residue annotation?
53
+ if fa.label in residue_annotations_tokenizer._labels:
54
+ ra_annotations.append(fa)
55
+ supported_label = True
56
+
57
+ if not supported_label:
58
+ raise ValueError(f"Unknown label in FunctionAnnotation: {fa.label}")
59
+
60
+ # Convert function token FunctionAnnotations -> Tensor
61
+ function_tokens = function_tokens_tokenizer.tokenize(
62
+ annotations=ft_annotations, seqlen=len(sequence)
63
+ )
64
+ function_token_ids = function_tokens_tokenizer.encode(
65
+ function_tokens, add_special_tokens=add_special_tokens
66
+ )
67
+
68
+ # Convert residue annotation FunctionAnnotations -> Tensor
69
+ if ra_annotations:
70
+ descriptions, starts, ends = zip(
71
+ *[(anot.label, anot.start, anot.end) for anot in ra_annotations]
72
+ )
73
+ else:
74
+ descriptions = starts = ends = None
75
+ ra_tokens = residue_annotations_tokenizer.tokenize(
76
+ {
77
+ "interpro_site_descriptions": descriptions,
78
+ "interpro_site_starts": starts,
79
+ "interpro_site_ends": ends,
80
+ },
81
+ sequence=sequence,
82
+ fail_on_mismatch=True,
83
+ )
84
+ residue_annotation_ids = residue_annotations_tokenizer.encode(
85
+ ra_tokens, add_special_tokens=add_special_tokens
86
+ )
87
+
88
+ return function_token_ids, residue_annotation_ids
89
+
90
+
91
+ def decode_function_tokens(
92
+ function_token_ids: torch.Tensor,
93
+ function_token_decoder: FunctionTokenDecoder,
94
+ function_tokens_tokenizer: InterProQuantizedTokenizer,
95
+ decoder_annotation_threshold: float = 0.1,
96
+ annotation_min_length: int | None = 5,
97
+ annotation_gap_merge_max: int | None = 3,
98
+ ) -> list[FunctionAnnotation]:
99
+ """Decodes function token ids into function annotation predictions.
100
+
101
+ Merges InterPro annotations and function keywords decoded from the
102
+ function token track into a single list of FunctionAnnotations.
103
+
104
+ Args:
105
+ function_token_ids: Tensor <int>[length, depth] of
106
+ function token ids.
107
+ function_token_decoder: decoder used to predict annotations from the tokens.
108
+ function_tokens_tokenizer: InterPro annotation tokenizer.
109
+ decoder_annotation_threshold: predicted probability threshold for emitting
110
+ a predicted function annotation.
111
+ annotation_min_length: minimum length of a predicted annotation to keep.
112
+ annotation_gap_merge_max: maximum gap size when merging adjacent annotations.
113
+ Returns:
114
+ Predicted function annotations merged from both predictions.
115
+ """
116
+ assert (
117
+ function_token_ids.ndim == 2
118
+ ), "function_token_ids must be of shape (length, depth)"
119
+
120
+ annotations: list[FunctionAnnotation] = []
121
+
122
+ # Function Annotations from predicted function tokens.
123
+ decoded = function_token_decoder.decode(
124
+ function_token_ids,
125
+ tokenizer=function_tokens_tokenizer,
126
+ annotation_threshold=decoder_annotation_threshold,
127
+ annotation_min_length=annotation_min_length,
128
+ annotation_gap_merge_max=annotation_gap_merge_max,
129
+ )
130
+
131
+ # Convert predicted InterPro annotation to FunctionAnnotation.
132
+ annotations.extend(decoded["function_keywords"])
133
+ for annotation in decoded["interpro_annotations"]:
134
+ annotation: FunctionAnnotation
135
+ label = function_tokens_tokenizer.format_annotation(annotation)
136
+ annotations.append(
137
+ FunctionAnnotation(label=label, start=annotation.start, end=annotation.end)
138
+ )
139
+
140
+ return annotations
141
+
142
+
143
+ def decode_residue_annotation_tokens(
144
+ residue_annotations_token_ids: torch.Tensor,
145
+ residue_annotations_tokenizer: ResidueAnnotationsTokenizer,
146
+ annotation_min_length: int | None = 5,
147
+ annotation_gap_merge_max: int | None = 3,
148
+ ) -> list[FunctionAnnotation]:
149
+ """Decodes residue annotation tokens into FunctionAnnotations.
150
+
151
+ Args:
152
+ residue_annotations_token_ids: Tensor <int>[length, MAX_RESIDUE_ANNOTATIONS]
154
+ of residue annotation tokens.
155
+ residue_annotations_tokenizer: tokenizer of residue annotations.
156
+ annotation_min_length, annotation_gap_merge_max: post-processing options.
156
+ Returns:
157
+ Predicted residue annotations.
158
+ """
159
+ assert (
160
+ residue_annotations_token_ids.ndim == 2
161
+ ), "residue_annotations_token_ids must be of shape (length, MAX_RESIDUE_ANNOTATIONS)"
162
+
163
+ annotations: list[FunctionAnnotation] = []
164
+
165
+ for depth in range(0, C.MAX_RESIDUE_ANNOTATIONS):
166
+ token_ids = residue_annotations_token_ids[:, depth]
167
+ nonzero_indices = torch.nonzero(token_ids).squeeze(dim=1).cpu().numpy()
168
+ if len(nonzero_indices) == 0:
169
+ continue
170
+ for loc in nonzero_indices:
171
+ vocab_index: int = token_ids[loc].item() # type: ignore
172
+ label = residue_annotations_tokenizer.vocabulary[vocab_index]
173
+ if label not in [*residue_annotations_tokenizer.special_tokens, "<none>"]:
174
+ annotation = FunctionAnnotation(label=label, start=loc, end=loc)
175
+ annotations.append(annotation)
176
+
177
+ annotations = merge_annotations(annotations, merge_gap_max=annotation_gap_merge_max)
178
+
179
+ # Drop very small annotations.
180
+ if annotation_min_length is not None:
181
+ annotations = [
182
+ annotation
183
+ for annotation in annotations
184
+ if annotation.end - annotation.start + 1 >= annotation_min_length
185
+ ]
186
+
187
+ return annotations
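
A sketch of the encoding entry point; sequence, function_tokenizer, and residue_tokenizer are hypothetical stand-ins for objects built elsewhere in the package:

from esm.utils.function.encode_decode import encode_function_annotations
from esm.utils.types import FunctionAnnotation

# start/end are 1-indexed and inclusive, per the assertion in encode_function_annotations.
annotations = [FunctionAnnotation(label="IPR000006", start=1, end=10)]
function_ids, residue_ids = encode_function_annotations(
    sequence=sequence,                                # hypothetical protein sequence string
    function_annotations=annotations,
    function_tokens_tokenizer=function_tokenizer,     # an InterProQuantizedTokenizer
    residue_annotations_tokenizer=residue_tokenizer,  # a ResidueAnnotationsTokenizer
)
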
Dyna-1/esm/utils/function/interpro.py ADDED
@@ -0,0 +1,178 @@
1
+ """Utilities for interacting with InterPro."""
2
+
3
+ import itertools
4
+ import re
5
+ from dataclasses import dataclass
6
+ from enum import IntEnum, auto
7
+ from functools import cached_property
8
+
9
+ import networkx as nx
10
+ import pandas as pd
11
+ from cloudpathlib import AnyPath
12
+
13
+ from esm.utils.constants import esm3 as C
14
+ from esm.utils.types import PathLike
15
+
16
+
17
+ def parse_go_terms(text: str) -> list[str]:
18
+ """Parses GO terms from a string.
19
+
20
+ Args:
21
+ text: String containing GO terms. Example: "GO:0008309, GO:1902267" Note that GO
22
+ terms have exactly 7 digits.
23
+ Returns:
24
+ All GO terms found in the string. Example: ['GO:0008309', 'GO:1902267']
25
+ """
26
+ return re.findall(r"GO:(?:\d{7,})", text)
27
+
28
+
29
+ def _parse_interpro2go(path: PathLike) -> dict[str, list[str]]:
30
+ """Parses InterPro2GO file into map.
31
+
32
+ NOTE: this file has a very strange, non-standard format.
33
+
34
+ Args:
35
+ path: path to InterPro2GO file from: https://www.ebi.ac.uk/GOA/InterPro2GO
36
+ Returns:
37
+ Mapping from InterPro to list of associated GO terms.
38
+ """
39
+ with AnyPath(path).open("r") as f:
40
+ text = f.read()
41
+ df = pd.Series(text.split("\n"), name="line").to_frame()
42
+ df = df[~df.line.str.startswith("!")]
43
+ df["interpro_id"] = df.line.apply(lambda line: re.findall(r"IPR\d+", line))
44
+ df["go_ids"] = df.line.apply(parse_go_terms)
45
+ df = df[df.go_ids.apply(len).gt(0) & df.interpro_id.apply(len).eq(1)]
46
+ df["interpro_id"] = df["interpro_id"].apply(lambda xs: xs[0]) # type: ignore
47
+
48
+ # Group all mappings together into a single map.
49
+ df = (
50
+ df.groupby("interpro_id")["go_ids"] # type: ignore
51
+ .apply(lambda group: list(itertools.chain.from_iterable(group)))
52
+ .reset_index()
53
+ )
54
+ return dict(zip(df.interpro_id, df.go_ids)) # type: ignore
55
+
56
+
57
+ class InterProEntryType(IntEnum):
58
+ """InterPro types and representation counts:
59
+
60
+ Family 21,942
61
+ Domain 14,053
62
+ Homologous_superfamily 3,446
63
+ Conserved_site 728
64
+ Repeat 374
65
+ Active_site 133
66
+ Binding_site 75
67
+ PTM 17
68
+ """
69
+
70
+ ACTIVE_SITE = 0
71
+ BINDING_SITE = auto()
72
+ CONSERVED_SITE = auto()
73
+ DOMAIN = auto()
74
+ FAMILY = auto()
75
+ HOMOLOGOUS_SUPERFAMILY = auto()
76
+ PTM = auto()
77
+ REPEAT = auto()
78
+ UNKNOWN = auto()
79
+
80
+
81
+ @dataclass
82
+ class InterProEntry:
83
+ """Represents an InterPro entry."""
84
+
85
+ id: str # Example: IPR000006
86
+ type: InterProEntryType
87
+ name: str # Example: "Metallothionein, vertebrate"
88
+ description: str | None = None
89
+
90
+
91
+ class InterPro:
92
+ """Convenience class interacting with InterPro ontology/data."""
93
+
94
+ def __init__(
95
+ self,
96
+ entries_path: PathLike | None = None,
97
+ hierarchy_path: PathLike | None = None,
98
+ interpro2go_path: PathLike | None = None,
99
+ ):
100
+ """Constructs interface to query InterPro entries."""
101
+
102
+ def default(x, d):
103
+ return x if x is not None else d
104
+
105
+ self.entries_path = default(entries_path, C.INTERPRO_ENTRY)
106
+ self.hierarchy_graph_path = default(hierarchy_path, C.INTERPRO_HIERARCHY)
107
+ self.interpro2go_path = default(interpro2go_path, C.INTERPRO2GO)
108
+
109
+ @cached_property
110
+ def interpro2go(self) -> dict[str, list[str]]:
111
+ """Reads the InterPro to GO term mapping."""
112
+ assert self.interpro2go_path is not None
113
+ return _parse_interpro2go(self.interpro2go_path)
114
+
115
+ @cached_property
116
+ def entries_frame(self) -> pd.DataFrame:
117
+ """Loads full InterPro entry set as a DataFrame.
118
+
119
+ Colums are
120
+ - "id": str interpro accession /id as
121
+ - "type": InterProEntryType representing the type of annotation.
122
+ - "name": Short name of the entry.
123
+ """
124
+ with AnyPath(self.entries_path).open("r") as f:
125
+ df = pd.read_csv(f, sep="\t")
126
+ assert all(
127
+ col in df.columns for col in ["ENTRY_AC", "ENTRY_TYPE", "ENTRY_NAME"]
128
+ )
129
+ df.rename(
130
+ columns={"ENTRY_AC": "id", "ENTRY_TYPE": "type", "ENTRY_NAME": "name"},
131
+ inplace=True,
132
+ )
133
+ df["type"] = df.type.str.upper().apply(
134
+ lambda type_name: InterProEntryType[type_name]
135
+ )
136
+ return df
137
+
138
+ @cached_property
139
+ def entries(self) -> dict[str, InterProEntry]:
140
+ """Returns all InterPro entries."""
141
+ return {
142
+ row.id: InterProEntry( # type: ignore
143
+ id=row.id, # type: ignore
144
+ type=row.type, # type: ignore
145
+ name=row.name, # type: ignore
146
+ )
147
+ for row in self.entries_frame.itertuples()
148
+ }
149
+
150
+ def lookup_name(self, interpro_id: str) -> str | None:
151
+ """Short name / title for an interpro id."""
152
+ if interpro_id not in self.entries:
153
+ return None
154
+ return self.entries[interpro_id].name
155
+
156
+ def lookup_entry_type(self, interpro_id: str) -> InterProEntryType:
157
+ """Looks up entry-type for an interpro id."""
158
+ if interpro_id in self.entries:
159
+ return self.entries[interpro_id].type
160
+ else:
161
+ return InterProEntryType.UNKNOWN
162
+
163
+ @cached_property
164
+ def graph(self) -> nx.DiGraph:
165
+ """Reads the InterPro hierarchy of InterPro."""
166
+ graph = nx.DiGraph()
167
+ with AnyPath(self.hierarchy_graph_path).open("r") as f:
168
+ parents = []
169
+ for line in f:
170
+ ipr = line.split("::", maxsplit=1)[0]
171
+ ipr_strip = ipr.lstrip("-")
172
+ level = (len(ipr) - len(ipr_strip)) // 2
173
+ parents = parents[:level]
174
+ graph.add_node(ipr_strip)
175
+ if parents:
176
+ graph.add_edge(ipr_strip, parents[-1])
177
+ parents.append(ipr_strip)
178
+ return graph
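
parse_go_terms is a plain regex scan, so it can be exercised standalone; this check matches the docstring example:

from esm.utils.function.interpro import parse_go_terms

assert parse_go_terms("GO:0008309, GO:1902267") == ["GO:0008309", "GO:1902267"]
assert parse_go_terms("no terms here") == []
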
Dyna-1/esm/utils/function/lsh.py ADDED
@@ -0,0 +1,102 @@
1
+ import numpy as np
2
+ from cloudpathlib import AnyPath
3
+
4
+ from esm.utils.types import PathLike
5
+
6
+
7
+ class LSHTable:
8
+ def __init__(self, n_bits: int, dim: int, hyperplanes: np.ndarray | None = None):
9
+ if hyperplanes is None:
10
+ hyperplanes = np.random.randn(n_bits, dim)
11
+ hyperplanes = hyperplanes / np.linalg.norm(
12
+ hyperplanes, axis=-1, keepdims=True
13
+ )
14
+ else:
15
+ assert hyperplanes.shape == (n_bits, dim), (
16
+ hyperplanes.shape,
17
+ (n_bits, dim),
18
+ )
19
+ assert hyperplanes is not None
20
+ self.hyperplanes: np.ndarray = hyperplanes
21
+ self.values = 1 << np.arange(n_bits)
22
+
23
+ def __call__(self, array, tokenize: bool = True):
24
+ similarity = self.hyperplanes @ array.T
25
+ bits = np.where(similarity >= 0, 1, 0)
26
+ if tokenize:
27
+ tokens = bits.T @ self.values
28
+ return tokens
29
+ else:
30
+ return bits.T
31
+
32
+
33
+ class LSHTokenized:
34
+ def __init__(
35
+ self,
36
+ n_bits: int,
37
+ dim: int,
38
+ num_tables: int = 1,
39
+ filepath: PathLike | None = None,
40
+ allow_create_hyperplanes: bool = False, # set this if you want the lsh to allow creation of hyperplanes
41
+ ):
42
+ table_hyperplanes = None
43
+ if filepath is not None:
44
+ filepath = AnyPath(filepath)
45
+ if not filepath.exists():
46
+ raise FileNotFoundError(filepath)
47
+ table_hyperplanes = np.load(filepath) # type: ignore
48
+ for i in range(num_tables):
49
+ assert str(i) in table_hyperplanes, f"Missing hyperplane for table {i}"
50
+ elif not allow_create_hyperplanes:
51
+ raise RuntimeError(
52
+ "Not allowed to create hyperplanes but no filepath provided"
53
+ )
54
+
55
+ self.tables = [
56
+ LSHTable(
57
+ n_bits,
58
+ dim,
59
+ table_hyperplanes[str(i)] if table_hyperplanes is not None else None,
60
+ )
61
+ for i in range(num_tables)
62
+ ]
63
+
64
+ def write_hyperplanes(self, filepath: PathLike):
65
+ hyperplanes: dict[str, np.ndarray] = { # type: ignore
66
+ str(i): table.hyperplanes for i, table in enumerate(self.tables)
67
+ }
68
+ np.savez(filepath, **hyperplanes)
69
+
70
+ def __call__(self, array):
71
+ tokens = np.stack([table(array) for table in self.tables], 1)
72
+ return tokens
73
+
74
+
75
+ class LSHBitstream:
76
+ def __init__(
77
+ self,
78
+ n_bits: int,
79
+ dim: int,
80
+ filepath: PathLike | None = None,
81
+ allow_create_hyperplanes: bool = False, # set this if you want the lsh to allow creation of hyperplanes
82
+ ):
83
+ table_hyperplanes = None
84
+ if filepath is not None:
85
+ filepath = AnyPath(filepath)
86
+ if not filepath.exists():
87
+ raise FileNotFoundError(filepath)
88
+ table_hyperplanes = np.load(filepath)
89
+ elif not allow_create_hyperplanes:
90
+ raise RuntimeError(
91
+ "Not allowed to create hyperplanes but no filepath provided"
92
+ )
93
+
94
+ self.table = LSHTable(
95
+ n_bits, dim, table_hyperplanes if table_hyperplanes is not None else None
96
+ )
97
+
98
+ def write_hyperplanes(self, filepath: PathLike):
99
+ np.save(filepath, self.table.hyperplanes)
100
+
101
+ def __call__(self, array):
102
+ return self.table(array, tokenize=False)
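
A minimal sketch of LSHTable on random data (hyperplanes are drawn at construction time, so tokens are only stable within a single table instance, or when loaded from a saved hyperplane file):

import numpy as np
from esm.utils.function.lsh import LSHTable

table = LSHTable(n_bits=8, dim=64)
vectors = np.random.randn(100, 64)
tokens = table(vectors)                # shape (100,), integer codes in [0, 2**8)
bits = table(vectors, tokenize=False)  # shape (100, 8), one sign bit per hyperplane
assert tokens.max() < 2 ** 8
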