Migrate application to Hugging Face
This view is limited to 50 files because it contains too many changes.
- .gitignore +4 -0
- README.md +93 -14
- app.py +151 -0
- frontend/constants.py +41 -0
- prot_xlstm_env.yml +39 -0
- protxlstm/__init__.py +1 -0
- protxlstm/applications/__init__.py +0 -0
- protxlstm/applications/fitness_prediction.py +214 -0
- protxlstm/applications/generation_utils/create_sequence_df.py +85 -0
- protxlstm/applications/generation_utils/score_hamming.py +80 -0
- protxlstm/applications/generation_utils/score_hmmer.py +102 -0
- protxlstm/applications/generation_utils/score_structure.py +55 -0
- protxlstm/applications/msa_sampler.py +196 -0
- protxlstm/applications/sample_sequences.py +200 -0
- protxlstm/applications/score_sequences.py +58 -0
- protxlstm/checkpoints/small/config.json +1 -0
- protxlstm/checkpoints/small/optimizer.pt +3 -0
- protxlstm/checkpoints/small/pytorch_model.bin +3 -0
- protxlstm/checkpoints/small/rng_state.pth +3 -0
- protxlstm/checkpoints/small/scheduler.pt +3 -0
- protxlstm/checkpoints/small/trainer_state.json +0 -0
- protxlstm/data.py +60 -0
- protxlstm/dataloaders.py +249 -0
- protxlstm/fim.py +203 -0
- protxlstm/generation.py +384 -0
- protxlstm/index.html +16 -0
- protxlstm/mamba_utils_generation.py +382 -0
- protxlstm/models/__init__.py +0 -0
- protxlstm/models/llama.py +342 -0
- protxlstm/models/mamba.py +833 -0
- protxlstm/models/xlstm.py +180 -0
- protxlstm/plot_utils.py +26 -0
- protxlstm/train.py +338 -0
- protxlstm/trainer.py +123 -0
- protxlstm/utils.py +482 -0
- protxlstm/xlstm/__init__.py +6 -0
- protxlstm/xlstm/blocks/__init__.py +0 -0
- protxlstm/xlstm/blocks/mlstm/__init__.py +1 -0
- protxlstm/xlstm/blocks/mlstm/backends.py +314 -0
- protxlstm/xlstm/blocks/mlstm/block.py +27 -0
- protxlstm/xlstm/blocks/mlstm/cell.py +212 -0
- protxlstm/xlstm/blocks/mlstm/layer.py +217 -0
- protxlstm/xlstm/blocks/xlstm_block.py +111 -0
- protxlstm/xlstm/components/__init__.py +0 -0
- protxlstm/xlstm/components/conv.py +163 -0
- protxlstm/xlstm/components/feedforward.py +88 -0
- protxlstm/xlstm/components/init.py +32 -0
- protxlstm/xlstm/components/linear_headwise.py +92 -0
- protxlstm/xlstm/components/ln.py +68 -0
- protxlstm/xlstm/components/rotary_position.py +35 -0
.gitignore
ADDED
@@ -0,0 +1,4 @@
+*/*.pyc
+*.pcy
+*/__pycache__
+__pycache__
README.md
CHANGED
@@ -1,14 +1,93 @@
+# prot-xLSTM-app
+
+
+
+## Getting started
+
+To make it easy for you to get started with GitLab, here's a list of recommended next steps.
+
+Already a pro? Just edit this README.md and make it your own. Want to make it easy? [Use the template at the bottom](#editing-this-readme)!
+
+## Add your files
+
+- [ ] [Create](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#create-a-file) or [upload](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#upload-a-file) files
+- [ ] [Add files using the command line](https://docs.gitlab.com/ee/gitlab-basics/add-file.html#add-a-file-using-the-command-line) or push an existing Git repository with the following command:
+
+```
+cd existing_repo
+git remote add origin https://git.bioinf.jku.at/chemoinformatics/prot-xlstm-app.git
+git branch -M main
+git push -uf origin main
+```
+
+## Integrate with your tools
+
+- [ ] [Set up project integrations](https://git.bioinf.jku.at/chemoinformatics/prot-xlstm-app/-/settings/integrations)
+
+## Collaborate with your team
+
+- [ ] [Invite team members and collaborators](https://docs.gitlab.com/ee/user/project/members/)
+- [ ] [Create a new merge request](https://docs.gitlab.com/ee/user/project/merge_requests/creating_merge_requests.html)
+- [ ] [Automatically close issues from merge requests](https://docs.gitlab.com/ee/user/project/issues/managing_issues.html#closing-issues-automatically)
+- [ ] [Enable merge request approvals](https://docs.gitlab.com/ee/user/project/merge_requests/approvals/)
+- [ ] [Set auto-merge](https://docs.gitlab.com/ee/user/project/merge_requests/merge_when_pipeline_succeeds.html)
+
+## Test and Deploy
+
+Use the built-in continuous integration in GitLab.
+
+- [ ] [Get started with GitLab CI/CD](https://docs.gitlab.com/ee/ci/quick_start/)
+- [ ] [Analyze your code for known vulnerabilities with Static Application Security Testing (SAST)](https://docs.gitlab.com/ee/user/application_security/sast/)
+- [ ] [Deploy to Kubernetes, Amazon EC2, or Amazon ECS using Auto Deploy](https://docs.gitlab.com/ee/topics/autodevops/requirements.html)
+- [ ] [Use pull-based deployments for improved Kubernetes management](https://docs.gitlab.com/ee/user/clusters/agent/)
+- [ ] [Set up protected environments](https://docs.gitlab.com/ee/ci/environments/protected_environments.html)
+
+***
+
+# Editing this README
+
+When you're ready to make this README your own, just edit this file and use the handy template below (or feel free to structure it however you want - this is just a starting point!). Thanks to [makeareadme.com](https://www.makeareadme.com/) for this template.
+
+## Suggestions for a good README
+
+Every project is different, so consider which of these sections apply to yours. The sections used in the template are suggestions for most open source projects. Also keep in mind that while a README can be too long and detailed, too long is better than too short. If you think your README is too long, consider utilizing another form of documentation rather than cutting out information.
+
+## Name
+Choose a self-explaining name for your project.
+
+## Description
+Let people know what your project can do specifically. Provide context and add a link to any reference visitors might be unfamiliar with. A list of Features or a Background subsection can also be added here. If there are alternatives to your project, this is a good place to list differentiating factors.
+
+## Badges
+On some READMEs, you may see small images that convey metadata, such as whether or not all the tests are passing for the project. You can use Shields to add some to your README. Many services also have instructions for adding a badge.
+
+## Visuals
+Depending on what you are making, it can be a good idea to include screenshots or even a video (you'll frequently see GIFs rather than actual videos). Tools like ttygif can help, but check out Asciinema for a more sophisticated method.
+
+## Installation
+Within a particular ecosystem, there may be a common way of installing things, such as using Yarn, NuGet, or Homebrew. However, consider the possibility that whoever is reading your README is a novice and would like more guidance. Listing specific steps helps remove ambiguity and gets people to using your project as quickly as possible. If it only runs in a specific context like a particular programming language version or operating system or has dependencies that have to be installed manually, also add a Requirements subsection.
+
+## Usage
+Use examples liberally, and show the expected output if you can. It's helpful to have inline the smallest example of usage that you can demonstrate, while providing links to more sophisticated examples if they are too long to reasonably include in the README.
+
+## Support
+Tell people where they can go to for help. It can be any combination of an issue tracker, a chat room, an email address, etc.
+
+## Roadmap
+If you have ideas for releases in the future, it is a good idea to list them in the README.
+
+## Contributing
+State if you are open to contributions and what your requirements are for accepting them.
+
+For people who want to make changes to your project, it's helpful to have some documentation on how to get started. Perhaps there is a script that they should run or some environment variables that they need to set. Make these steps explicit. These instructions could also be useful to your future self.
+
+You can also document commands to lint the code or run tests. These steps help to ensure high code quality and reduce the likelihood that the changes inadvertently break something. Having instructions for running tests is especially helpful if it requires external setup, such as starting a Selenium server for testing in a browser.
+
+## Authors and acknowledgment
+Show your appreciation to those who have contributed to the project.
+
+## License
+For open source projects, say how it is licensed.
+
+## Project status
+If you have run out of energy or time for your project, put a note at the top of the README saying that development has slowed down or stopped completely. Someone may choose to fork your project or volunteer to step in as a maintainer or owner, allowing your project to keep going. You can also make an explicit request for maintainers.
app.py
ADDED
@@ -0,0 +1,151 @@
+import os
+os.environ["CUDA_VISIBLE_DEVICES"] = ""  # disable cuda
+
+import streamlit as st
+import numpy as np
+import torch
+import time
+from Bio import SeqIO
+
+from protxlstm.applications.fitness_prediction import single_mutation_landscape_xlstm, create_mutation_df
+from protxlstm.applications.msa_sampler import sample_msa
+from protxlstm.models.xlstm import xLSTMLMHeadModel
+from protxlstm.utils import load_model
+import io
+
+from frontend.constants import info_text, citation_text
+
+DEFAULT_SEQUENCE = "MTARGLALGLLLLLLCPAQVFSQSCVWYGECGIAYGDKRYNCEYSGPPKPLPKDGYDLVQELCPGFFFGNVSLCCDVRQLQTLKDNLQLPLQFLSRCPSCFYNLLNLFCELTCSPRQSQFLNVTATEDYVDPVTNQTKTNVKELQYYVGQSFANAMYNACRDVEAPSSNDKALGLLCGKDADACNATNWIEYMFNKDNGQAPFTITPVFSDFPVHGMEPMNNATKGCDESVDEVTAPCSCQDCSIVCGPKPQPPPPPAPWTILGLDAMYVIMWITYMAFLLVFFGAFFAVWCYRKRYFVSEYTPIDSNIAFSVNASDKGEASCCDPVSAAFEGCLRRLFTRWGSFCVRNPGCVIFFSLVFITACSSGLVFVRVTTNPVDLWSAPSSQARLEKEYFDQHFGPFFRTEQLIIRAPLTDKHIYQPYPSGADVPFGPPLDIQILHQVLDLQIAIENITASYDNETVTLQDICLAPLSPYNTNCTILSVLNYFQNSHSVLDHKKGDDFFVYADYHTHFLYCVRAPASLNDTSLLHDPCLGTFGGPVFPWLVLGGYDDQNYNNATALVITFPVNNYYNDTEKLQRAQAWEKEFINFVKNYKNPNLTISFTAERSIEDELNRESDSDVFTVVISYAIMFLYISLALGHMKSCRRLLVDSKVSLGIAGILIVLSSVACSLGVFSYIGLPLTLIVIEVIPFLVLAVGVDNIFILVQAYQRDERLQGETLDQQLGRVLGEVAPSMFLSSFSETVAFFLGALSVMPAVHTFSLFAGLAVFIDFLLQITCFVSLLGLDIKRQEKNRLDIFCCVRGAEDGTSVQASESCLFRFFKNSYSPLLLKDWMRPIVIAIFVGVLSFSIAVLNKVDIGLDQSLSMPDDSYMVDYFKSISQYLHAGPPVYFVLEEGHDYTSSKGQNMVCGGMGCNNDSLVQQIFNAAQLDNYTRIGFAPSSWIDDYFDWVKPQSSCCRVDNITDQFCNASVVDPACVRCRPLTPEGKQRPQGGDFMRFLPMFLSDNPNPKCGKGGHAAYSSAVNILLGHGTRVGATYFMTYHTVLQTSADFIDALKKARLIASNVTETMGINGSAYRVFPYSVFYVFYEQYLTIIDDTIFNLGVSLGAIFLVTMVLLGCELWSAVIMCATIAMVLVNMFGVMWLWGISLNAVSLVNLVMSCGISVEFCSHITRAFTVSMKGSRVERAEEALAHMGSSVFSGITLTKFGGIVVLAFAKSQIFQIFYFRMYLAMVLLGATHGLIFLPVLLSYIGPSVNKAKSCATEERYKGTERERLLNF"
+
+mutation_positions = []
+msa_file = None
+
+if 'fitness_done' not in st.session_state:
+    st.session_state.fitness_done = False
+    st.session_state.mutations = None
+    st.session_state.fitness_duration = None
+    st.session_state.target_sequence = ""
+    st.session_state.context_sequences = []
+    st.session_state.num_context_sequences = 25
+
+def run_model():
+    try:
+        st.session_state.fitness_duration = time.time()
+        checkpoint = "protxlstm/checkpoints/small"
+        num_context_tokens = 2**15
+        df_mutations = create_mutation_df(st.session_state.target_sequence, mutation_positions)
+        if msa_file != None and st.session_state.num_context_sequences != 0:
+            def load_sequences_from_msa_file(file_obj):
+                text_io = io.TextIOWrapper(file_obj, encoding="utf-8")
+                sequences = [str(record.seq) for record in SeqIO.parse(text_io, "fasta")]
+                return sequences
+            msa_sequences = [msa.upper() for msa in load_sequences_from_msa_file(msa_file)]
+            st.session_state.context_sequences = sample_msa(msa_sequences, max_context_sequences=st.session_state.num_context_sequences, context_length=num_context_tokens)
+            st.session_state.context_sequences += [st.session_state.target_sequence]
+
+        config_update_kwargs = {
+            "mlstm_backend": "chunkwise_variable",
+            "mlstm_chunksize": 1024,
+            "mlstm_return_last_state": True}
+
+        model = load_model(
+            checkpoint,
+            model_class=xLSTMLMHeadModel,
+            device='cpu',
+            dtype=torch.bfloat16,
+            **config_update_kwargs,
+        )
+        model = model.eval()
+        st.session_state.mutations, _ = single_mutation_landscape_xlstm(model, df_mutations, st.session_state.context_sequences, chunk_chunk_size=2**15)
+        print("fitness_done")
+        st.session_state.fitness_done = True
+        st.session_state.fitness_duration = time.time() - st.session_state.fitness_duration
+    except Exception as e:
+        print(e)
+
+# PAGE STYLE (mainly for custom aa selection)
+st.set_page_config(layout="wide")
+st.markdown(
+    """
+    <style>
+    .stButtonGroup button {
+        padding: 0px 1px 0px 1px !important;
+        border: 0 solid transparent !important;
+        min-height: 0px !important;
+        line-height: 120% !important;
+        height: auto !important;
+    }
+    .stSidebar {
+        width: 600px !important;
+    }
+    </style>
+    """,
+    unsafe_allow_html=True
+)
+
+
+with st.sidebar:
+    st.title("Prot-xLSTM Variant Fitness")
+
+    # LOAD SEQUENCE
+    st.session_state.target_sequence = st.text_area(
+        "Target protein sequence",
+        placeholder=DEFAULT_SEQUENCE,
+        value=st.session_state.target_sequence
+    )
+    if st.button("Load sequence"):
+        if st.session_state.target_sequence == "":
+            st.session_state.target_sequence = DEFAULT_SEQUENCE
+
+    # MANAGE CONTEXT SEQUENCES
+    context_type = st.selectbox(
+        "Choose how to enter context",
+        ("Enter manually", "Use MSA file"),
+        index=None,
+        placeholder="Choose context",
+    )
+    if context_type == 'Enter manually':
+        context_sequence_str = st.text_area(
+            "Enter context protein sequences (separated by comma)",
+            placeholder=DEFAULT_SEQUENCE,
+        )
+        st.session_state.context_sequences = context_sequence_str.split(",") + [st.session_state.target_sequence]
+    elif context_type == 'Use MSA file':
+        msa_file = st.file_uploader("Choose MSA file")
+        st.session_state.num_context_sequences = st.number_input("How many of these sequences should be used?", min_value=0, step=1, value=25)
+    else:
+        st.session_state.context_sequences = [st.session_state.target_sequence]
+
+if st.session_state.target_sequence != "":
+    with st.container():
+
+        # MUTATION POSITION SELECTION
+        aas = list(st.session_state.target_sequence)
+        mutation_indices = np.arange(1, len(aas)+1)
+        mutation_positions = st.segmented_control(
+            "Choose mutation positions (click to select)", mutation_indices, selection_mode="multi", format_func=lambda i: aas[i-1],
+        )
+        st.button("Check Fitness", on_click=run_model)
+
+# DISPLAY RESULTS
+if st.session_state.fitness_done:
+    st.metric(label="Running time", value=f"{st.session_state.fitness_duration:.2f} sec.")
+    selected_pos = st.selectbox(
+        "Visualized mutation position",
+        st.session_state.mutations['position'].unique()
+    )
+    selected_data = st.session_state.mutations.where(st.session_state.mutations['position'] == selected_pos)
+    st.bar_chart(selected_data, x='mutation', y='effect', horizontal=True)
+    st.dataframe(st.session_state.mutations, use_container_width=True)
+
+# TUTORIAL
+with st.expander("Info & Tutorial", expanded=True):
+    st.subheader("Tutorial")
+    st.markdown("**1.** Choose a target protein sequence (leave empty to use a sample sequence) and press 'Load Sequence'")
+    st.markdown("**2.** Enter or upload your context sequences (leave empty to use no context)")
+    st.markdown("**3.** Choose which amino acids to mutate (click on the AA's to select them) and press 'Check Fitness'")
+    st.subheader("General Information")
+    st.markdown(info_text, unsafe_allow_html=True)
+    st.markdown("")
+    st.subheader("Cite us / BibTex")
+    st.code(citation_text, language=None)
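Usage note: once the dependencies from prot_xlstm_env.yml (below) are installed, this interface is typically launched from the repository root with the standard Streamlit entry point, `streamlit run app.py`; no project-specific launcher is assumed.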
frontend/constants.py
ADDED
@@ -0,0 +1,41 @@
+info_text = ("""
+<div style="text-align: justify">
+This application enables you to inspect mutational effects on a
+predefined protein sequence.<br>
+<br>
+</div>
+
+<div style="text-align: justify">
+It is built on the Prot-xLSTM backbone model, an xLSTM model specifically
+trained on protein sequences. Prot-xLSTM was trained using the
+Fill-In-the-Middle (FIM) objective, which allows it to perform sequence
+inpainting. Additionally, the model can be provided with a potentially
+large set of homologous sequences to enhance its predictions.<br>
+<br>
+</div>
+
+<div style="text-align: justify">
+For further information, please refer to: <a href="https://openreview.net/forum?id=IjbXZdugdj" target="_blank">https://openreview.net/forum?id=IjbXZdugdj</a>. <br>
+<br>
+
+This Hugging Face application is based on the following GitHub repository:
+<a href="https://github.com/ml-jku/Prot-xLSTM?tab=readme-ov-file" target="_blank">https://github.com/ml-jku/Prot-xLSTM?tab=readme-ov-file</a>. <br>
+The Streamlit application was developed by Elias Bürger.
+</div>
+
+<div style="text-align: justify">
+Please cite us as follows: <br>
+</div>
+""")
+citation_text = """
+@misc{
+    schmidinger2024bioxlstmgenerativemodelingrepresentation,
+    title={Bio-xLSTM: Generative modeling, representation and in-context learning of biological and chemical sequences},
+    author={Niklas Schmidinger and Lisa Schneckenreiter and Philipp Seidl and Johannes Schimunek and Pieter-Jan Hoedt and Johannes Brandstetter and Andreas Mayr and Sohvi Luukkonen and Sepp Hochreiter and Günter Klambauer},
+    year={2024},
+    eprint={2411.04165},
+    archivePrefix={arXiv},
+    primaryClass={q-bio.BM},
+    url={https://arxiv.org/abs/2411.04165},
+}
+"""
prot_xlstm_env.yml
ADDED
@@ -0,0 +1,39 @@
+name: prot_xlstm_app
+channels:
+  - pytorch
+  - nvidia
+  - conda-forge
+  - defaults
+dependencies:
+  - cuda=12.1
+  - cuda-nvcc=12.1
+  - gxx_linux-64=11.2.0
+  - python=3.11
+  - pip
+  - pytorch=2.2.0
+  - pytorch-cuda=12.1
+  - cmake
+  - ninja
+  - pip:
+    - accelerate>=0.26.0
+    - biopython #==1.83
+    - bottleneck #==1.4.2
+    - dacite #==1.8.1
+    - ipykernel #==6.29.3
+    - mamba_ssm==1.2.0
+    - matplotlib #==3.8.4
+    - numpy<2.0 #==1.26.4
+    - omegaconf #==2.3.0
+    - pandas #==2.2.2
+    - pyhmmer #==0.10.15
+    - rich #==13.7.1
+    - scipy #==1.13.0
+    - seaborn #==0.13.2
+    - torchmetrics #==1.2.1
+    - tqdm #==4.66.4
+    - transformers==4.44.2
+    - tueplots #==0.0.17
+    - wandb #==0.17.0
+    - streamlit #==1.43.2
+
+
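Usage note: the environment above can typically be created with `conda env create -f prot_xlstm_env.yml` and activated with `conda activate prot_xlstm_app` (standard conda commands); note that app.py itself disables CUDA and runs the model on CPU.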
protxlstm/__init__.py
ADDED
@@ -0,0 +1 @@
+# __version__ = "0.0.1"
protxlstm/applications/__init__.py
ADDED
File without changes
protxlstm/applications/fitness_prediction.py
ADDED
@@ -0,0 +1,214 @@
+import os
+import numpy as np
+import pandas as pd
+
+import torch
+from tqdm.auto import tqdm
+
+from protxlstm.applications.msa_sampler import MSASampler
+from protxlstm.generation import generate_sequence
+from protxlstm.utils import AA_TO_ID, tokenizer, ID_TO_AA
+
+
+def precompute_context_state(model, sequences, chunk_chunk_size=2**15):
+    """
+    Precompute the output states for a fixed context that remains the same across generations.
+    Returns the hidden states to continue generation later.
+    """
+    device = next(model.parameters()).device
+
+    input_ids, pos_ids = prepare_context(sequences)
+    state = None
+
+    for chunk in range(input_ids.shape[1]//chunk_chunk_size+1):
+
+        start_idx = chunk*chunk_chunk_size
+        end_idx = min((chunk+1)*chunk_chunk_size, input_ids.shape[1])
+
+        if start_idx == end_idx:
+            pass
+
+        else:
+            input_ids_chunk = input_ids[:, start_idx:end_idx].to(device)
+            pos_ids_chunk = pos_ids[:, start_idx:end_idx].to(device)
+
+            with torch.no_grad():
+                outputs = model(input_ids=input_ids_chunk,
+                                position_ids=pos_ids_chunk,
+                                state=state,
+                                output_hidden_states=True,
+                                return_dict=True)
+            state = outputs.state
+
+    # Return the hidden states for reuse
+    return state
+
+def prepare_context(sequences):
+    tokenized_sequences = tokenizer(sequences, concatenate=False)
+    pos_ids = torch.cat([torch.arange(0, len(seq), dtype=torch.int64) for seq in tokenized_sequences], 0)[None, :]
+    input_ids = torch.cat(tokenized_sequences, 0)[None, :].to(torch.int64)
+    return input_ids, pos_ids
+
+def prepare_single_mutation_target(target, mut_pos):
+
+    pos_ids = torch.arange(target.shape[1], dtype=torch.int64)[None,:]  # default position ids
+    t = torch.ones((target.shape[0], 1), dtype=torch.int64)
+    new_target = torch.cat([
+        target[:,:mut_pos],             # WT sequence until mutated position
+        AA_TO_ID["<mask-1>"] * t,       # Mask token at the mutated position
+        target[:,mut_pos+1:],           # WT sequence after mutated position
+        AA_TO_ID["<eos>"] * t,          # End of sequence token
+        AA_TO_ID["<mask-1>"] * t,       # Mask token
+    ], dim=1)
+    new_pos_ids = torch.cat([
+        pos_ids,
+        0 * t,        # end of sequence
+        mut_pos * t,  # mutation position
+    ], dim=1)
+
+    is_fim_dict = { AA_TO_ID["<mask-1>"] : pos_ids[:,mut_pos].squeeze().item()}
+
+    return new_target, new_pos_ids, is_fim_dict
+
+def single_mutation_landscape_xlstm(model, single_mutations, context_sequences, chunk_chunk_size=2**15):
+
+    device = next(model.parameters()).device
+
+    # Tokenize WT target sequence
+    wt_tokens = tokenizer([context_sequences[-1]], concatenate=True)
+
+    # Precompute hidden state of context
+    context_state = precompute_context_state(model, context_sequences, chunk_chunk_size=chunk_chunk_size)
+
+    mutation_positions = sorted(single_mutations.position.unique())
+    all_logits = np.zeros((len(mutation_positions), 20))
+
+    # Iterate over all mutated positions
+    for i, pos in tqdm(enumerate(mutation_positions), total=len(mutation_positions), desc="Generating mutational landscape"):  # This loop can be parallelized
+
+        # Prepare target
+        wt_aa_id = wt_tokens[0, pos+1].int().item()  # wild type AA index
+        target_tokens, target_pos_ids, _ = prepare_single_mutation_target(wt_tokens, pos+1)
+
+        with torch.no_grad():
+            outputs = model(input_ids=target_tokens.to(device),
+                            position_ids=target_pos_ids.to(device),
+                            state=context_state,
+                            )
+
+        # Extract logits and compute mutational effect
+        logits = outputs.logits.clone().detach()  # Raw logits
+        logits_mut = logits[0, -1, 4:24].log_softmax(-1)  # Log-softmax for mutation prediction: (4-24) correspond to natural AAs
+        mut_effects = logits_mut - logits_mut[wt_aa_id - 4]  # Subtract log probability of ground truth
+        all_logits[i,:] = logits_mut.cpu()
+        single_mutations.loc[single_mutations.position == pos, 'effect'] = single_mutations.loc[single_mutations.position == pos, 'mutation_idx'].apply(lambda x : mut_effects[x-4].item())
+
+    return single_mutations, all_logits
+
+def single_mutation_landscape_mamba(model, single_mutations, context_sequences):
+
+    # Prepare context sequences
+    context_tokens, context_pos_ids = prepare_context(context_sequences)
+
+    # Tokenize WT target sequence
+    wt_tokens = tokenizer([context_sequences[-1]], concatenate=True)
+
+    mutation_positions = sorted(single_mutations.position.unique())
+    all_logits = np.zeros((len(mutation_positions), 20))
+
+    # Iterate over all mutated positions
+    for i, pos in tqdm(enumerate(mutation_positions), total=len(mutation_positions), desc="Generating mutational landscape"):  # This loop can be parallelized
+
+        # Prepare target
+        wt_aa_id = wt_tokens[0, pos+1].int().item()  # wild type AA index
+        target_tokens, target_pos_ids, is_fim_dict = prepare_single_mutation_target(wt_tokens, pos+1)
+
+        # Merge context and target
+        device = next(model.parameters()).device
+        context_tokens = torch.cat([context_tokens, target_tokens], dim=1).to(device)
+        context_pos_ids = torch.cat([context_pos_ids, target_pos_ids], dim=1).to(device)
+
+        # Generate fim-token prediction
+        output = generate_sequence(
+            model,
+            context_tokens,
+            position_ids=context_pos_ids,
+            is_fim=is_fim_dict,
+            max_length=1,
+            temperature=1.0,
+            top_k=0,
+            top_p=0.0,
+            return_dict_in_generate=True,
+            output_scores=True,
+            eos_token_id=AA_TO_ID["<cls>"],
+            device=device
+        )
+
+        # Extract logits and compute mutational effect
+        logits = torch.tensor(output["scores"])  # Raw logits
+        logits_mut = logits[0, 0, 4:24].log_softmax(-1)  # Log-softmax for mutation prediction: (4-24) correspond to natural AAs
+        mut_effects = logits_mut - logits_mut[wt_aa_id - 4]  # Subtract log probability of ground truth
+        all_logits[i,:] = logits_mut.cpu()
+
+        single_mutations.loc[single_mutations.position == pos, 'effect'] = single_mutations.loc[single_mutations.position == pos, 'mutation_idx'].apply(lambda x : mut_effects[x-4].item())
+
+    return single_mutations, all_logits
+
+def single_mutation_landscape_retrieval(single_mutations, msa_sequences, msa_weights_path):
+
+    # One-hot encode MSA sequences
+    msa_tokens = np.array([[AA_TO_ID[aa.upper()] for aa in seq] for seq in msa_sequences])
+    one_hot_tokens = np.zeros((len(msa_tokens), len(msa_tokens[0]), 40))
+    one_hot_tokens[np.arange(len(msa_tokens))[:, None], np.arange(len(msa_tokens[0])), msa_tokens] = 1
+
+    # Load/compute weights
+    if os.path.exists(msa_weights_path):
+        weights = np.load(msa_weights_path)
+    else:
+        sampler = MSASampler(0.98, 0.7)
+        weights = sampler.get_weights(msa_tokens)[1]
+        np.save(msa_weights_path, weights)
+    assert one_hot_tokens.shape[0] == weights.shape[0]
+
+    # Apply sequence weights, normalize amino acid probabilities per position, and convert to a PyTorch tensor.
+    one_hot_tokens = one_hot_tokens * weights[:, None, None]
+    one_hot_tokens = one_hot_tokens.sum(0)
+    one_hot_tokens = one_hot_tokens[:, 4:24] + 1 / len(msa_sequences)
+    one_hot_tokens_sum = one_hot_tokens.sum(-1)
+    one_hot_tokens = one_hot_tokens / one_hot_tokens_sum[:, None]
+    one_hot_tokens = torch.tensor(one_hot_tokens).float()
+
+    # Compute mutational effects
+    wild_type = msa_tokens[0]
+    logits = one_hot_tokens.log()
+    logits = logits - logits[torch.arange(len(logits)), wild_type - 4][:, None]
+
+    single_mutations['retrieval_effect'] = single_mutations.apply(
+        lambda row: logits[row['position'], row['mutation_idx'] - 4].item(), axis=1)
+
+    return single_mutations
+
+
+def create_mutation_df(sequence, mutation_positions):
+    """
+    Generate a DataFrame containing all possible mutations at specified positions in a sequence.
+
+    Args:
+        sequence (str): The original sequence to mutate.
+        mutation_positions (list of int): List of positions to mutate (1-based index).
+
+    Returns:
+        pd.DataFrame:
+            - 'mutation': formatted mutation string (e.g., 'A10G' for Ala at position 10 to Gly).
+            - 'position': 0-based position in the sequence.
+            - 'mutation_idx': numeric index for the mutation.
+    """
+
+    AAs = {k: v for k, v in ID_TO_AA.items() if 4 <= k <= 23}
+    mutation_data = []
+    for position in mutation_positions:
+        wt = sequence[position - 1]
+        for idx, aa in AAs.items():
+            mutation = f"{wt}{position}{aa}"
+            mutation_data.append({'mutation': mutation, 'position': position - 1, 'mutation_idx': idx})
+    return pd.DataFrame(mutation_data)
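For reference, the sketch below shows how these helpers are combined outside of Streamlit, mirroring the `run_model` function in app.py. It is a minimal sketch under stated assumptions: the target sequence, context list, and mutation positions are placeholders, and the checkpoint path is the one added in this commit.

```python
# Minimal sketch (assumptions: placeholder sequences; CPU inference as in app.py).
import torch

from protxlstm.applications.fitness_prediction import create_mutation_df, single_mutation_landscape_xlstm
from protxlstm.models.xlstm import xLSTMLMHeadModel
from protxlstm.utils import load_model

target = "MTARGLALGLLLLLLCPAQ"   # placeholder wild-type sequence
context = [target]              # homologous context sequences; the target goes last
df_mutations = create_mutation_df(target, mutation_positions=[3, 7])  # 1-based positions

model = load_model(
    "protxlstm/checkpoints/small",   # checkpoint shipped in this commit
    model_class=xLSTMLMHeadModel,
    device="cpu",
    dtype=torch.bfloat16,
    mlstm_backend="chunkwise_variable",
    mlstm_chunksize=1024,
    mlstm_return_last_state=True,
).eval()

# Score every single-site substitution at the chosen positions.
mutations, logits = single_mutation_landscape_xlstm(
    model, df_mutations, context, chunk_chunk_size=2**15
)
print(mutations[["mutation", "effect"]].sort_values("effect", ascending=False).head())
```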
protxlstm/applications/generation_utils/create_sequence_df.py
ADDED
@@ -0,0 +1,85 @@
+import numpy as np
+import pickle
+import pandas as pd
+
+from protxlstm.dataloaders import ProteinMemmapDataset
+from protxlstm.utils import decode_sequence, reorder_masked_sequence
+
+
+def create_sequence_df(model_name, family_idx, parameters_list=None, num_sequences = 100, data_dir="./data/"):
+
+    #load dataset
+    dataset = ProteinMemmapDataset(
+        msa_memmap_path=f"{data_dir}open_protein_set_memmap.dat",
+        msa_memmap_meta_path=f"{data_dir}open_protein_set_memmap_indices.csv",
+        subset_path=f"{data_dir}/cluster_testing_set.txt",
+        sample=False,
+        max_msa_len=-1,
+        reverse=False,
+        seed=0,
+        troubleshoot=False,
+        fim_strategy="multiple_span",
+        always_mask=False,
+        max_position_embeddings=2048,
+        max_seq_position_embeddings=512,
+        add_position_ids="1d",
+        mask_fraction=0.2,
+        max_patches=5,
+    )
+
+    family_id = list(dataset.dataset_meta["msa_id"])[family_idx]
+
+    if model_name == "natural":
+
+        data = dataset[family_idx]
+        sequence_df = pd.DataFrame(columns=["family", "family_id", "sequence", "sequence_length"])
+        tokens = data["input_ids"][None,:]
+        all_context = decode_sequence(tokens[0].cpu().numpy())
+        list_sequences_msa = [reorder_masked_sequence(elem+"<cls>") for elem in all_context.split("<cls>")[1:-1]]
+
+        rd_idxs = np.random.choice(len(list_sequences_msa), num_sequences, replace=False)
+        natural_sequences = [seq for i, seq in enumerate(list_sequences_msa) if i in rd_idxs]
+
+        df_dict = {"family": [family_idx]*len(natural_sequences),
+                   "family_id": [family_id]*len(natural_sequences),
+                   "sequence": natural_sequences,
+                   "sequence_length": [len(seq) for seq in natural_sequences]}
+
+        sequence_df = pd.concat([sequence_df, pd.DataFrame(df_dict)], ignore_index = True)
+
+    else:
+
+        sequence_df = pd.DataFrame(columns=["family", "family_id", "n_seqs_ctx", "temperature", "top_k", "top_p", "original_sequence", "sequence", "sequence_length", "perplexity"])
+
+        if parameters_list is None:
+            parameters_list = [(10,1.,10,1.), (10,1.,15,1.), (10,1.,10,0.95), (10,0.9,10,0.95), (10,0.8,10,0.9),
+                               (100,1.,10,1.), (100,1.,15,1.), (100,1.,10,0.95), (100,0.9,10,0.95), (100,0.8,10,0.9),
+                               (500,1.,10,1.), (500,1.,15,1.), (500,1.,10,0.95), (500,0.9,10,0.95), (500,0.8,10,0.9),
+                               (1000,1.,10,1.), (1000,1.,15,1.), (1000,1.,10,0.95), (1000,0.9,10,0.95), (1000,0.8,10,0.9),
+                               (-1,1.,10,1.), (-1,1.,15,1.), (-1,1.,10,0.95), (-1,0.9,10,0.95), (-1,0.8,10,0.9)]
+
+        for param in parameters_list:
+            n_seqs_ctx, temperature, top_k, top_p = param
+
+            with open(f"evaluation/generation/generated_sequences/{model_name}/{family_idx}_{param}_{num_sequences}", "rb") as f:
+                gen_seqs = pickle.load(f)
+
+            original_sequences = list(gen_seqs[family_idx][param].keys())
+            reordered_sequences = [reorder_masked_sequence(seq) for seq in original_sequences]
+            perplexities = [gen_seqs[family_idx][param][seq]["perplexity"] for seq in original_sequences]
+            df_dict = {"family": [family_idx]*len(original_sequences),
+                       "family_id": [family_id]*len(original_sequences),
+                       "n_seqs_ctx": [n_seqs_ctx]*len(original_sequences),
+                       "temperature": [temperature]*len(original_sequences),
+                       "top_k": [top_k]*len(original_sequences),
+                       "top_p": [top_p]*len(original_sequences),
+                       "original_sequence": original_sequences,
+                       "sequence": reordered_sequences,
+                       "sequence_length": [len(seq) for seq in reordered_sequences],
+                       "perplexity": perplexities
+                       }
+
+            sequence_df = pd.concat([sequence_df, pd.DataFrame(df_dict)], ignore_index = True)
+
+    return sequence_df
+
protxlstm/applications/generation_utils/score_hamming.py
ADDED
@@ -0,0 +1,80 @@
+import numpy as np
+from tqdm import tqdm
+import pandas as pd
+from Bio import Align
+
+from protxlstm.dataloaders import ProteinMemmapDataset
+from protxlstm.utils import decode_sequence, reorder_masked_sequence
+
+
+aligner = Align.PairwiseAligner()
+aligner.mode = 'global'
+aligner.match_score = 1
+aligner.mismatch_score = -1
+aligner.open_gap_score = -1
+aligner.extend_gap_score = -1
+
+def align_sequences(ref_seq, query_seq, print_alignments=False):
+    def hamming_str(s1,s2):
+        assert len(s1) == len(s2)
+        return sum(np.array(list(s1)) != np.array(list(s2)))/len(s1)
+    alignments = aligner.align(ref_seq, query_seq)
+    if print_alignments:
+        print("Score = %.1f:" % alignments[0].score)
+        print(alignments[0])
+    return hamming_str(alignments[0][0], alignments[0][1]), alignments[0][0], alignments[0][1]
+
+
+def score_hamming(sequence_df, family_idx, data_dir = f"./data/"):
+
+    assert len(set(list(sequence_df["family"]))) == 1 and sequence_df["family"].iloc[0] == family_idx
+
+    #load dataset
+    dataset = ProteinMemmapDataset(
+        msa_memmap_path=f"{data_dir}open_protein_set_memmap.dat",
+        msa_memmap_meta_path=f"{data_dir}open_protein_set_memmap_indices.csv",
+        subset_path=f"{data_dir}/cluster_testing_set.txt",
+        sample=False,
+        max_msa_len=-1,
+        reverse=False,
+        seed=0,
+        troubleshoot=False,
+        fim_strategy="multiple_span",
+        always_mask=False,
+        max_position_embeddings=2048,
+        max_seq_position_embeddings=512,
+        add_position_ids="1d",
+        mask_fraction=0.2,
+        max_patches=5,
+    )
+
+    # Select a sample of the dataset to be the input
+    data = dataset[family_idx]
+    tokens = data["input_ids"][None,:]
+    all_context = decode_sequence(tokens[0].cpu().numpy())
+    list_sequences_msa = [reorder_masked_sequence(elem+"<cls>") for elem in all_context.split("<cls>")[1:-1]]
+
+    # sequence_df["hamming"] = pd.Series(dtype=object)
+    sequence_df["min_hamming"] = pd.Series()
+    sequence_df["median_hamming"] = pd.Series()
+    sequence_df["mean_hamming"] = pd.Series()
+    sequence_df["std_hamming"] = pd.Series()
+
+    for seq in tqdm(list(sequence_df["sequence"])):
+
+        all_hamming = []
+        for ctx_seq in list_sequences_msa:
+            if ctx_seq == seq:
+                continue
+            else:
+                hamming, _, _ = align_sequences(ctx_seq, seq , print_alignments=False)
+                all_hamming.append(hamming)
+
+        # sequence_df.loc[sequence_df["sequence"] == seq, "hamming"] = [all_hamming]
+        sequence_df.loc[sequence_df["sequence"] == seq, "min_hamming"] = np.min(all_hamming)
+        sequence_df.loc[sequence_df["sequence"] == seq, "median_hamming"] = np.median(all_hamming)
+        sequence_df.loc[sequence_df["sequence"] == seq, "mean_hamming"] = np.mean(all_hamming)
+        sequence_df.loc[sequence_df["sequence"] == seq, "std_hamming"] = np.std(all_hamming)
+
+    return sequence_df
+
protxlstm/applications/generation_utils/score_hmmer.py
ADDED
@@ -0,0 +1,102 @@
+import string
+from Bio import SeqIO
+import pyhmmer
+from tqdm import tqdm
+
+alphabet = pyhmmer.easel.Alphabet.amino()
+
+# This is an efficient way to delete lowercase characters and insertion characters from a string
+deletekeys = dict.fromkeys(string.ascii_lowercase)
+deletekeys["."] = None
+deletekeys["*"] = None
+translation = str.maketrans(deletekeys)
+
+def remove_insertions(sequence: str) -> str:
+    """ Removes any insertions into the sequence. Needed to load aligned sequences in an MSA. """
+    return sequence.translate(translation)
+
+def read_msa(filename: str):
+    """ Reads the sequences from an MSA file, automatically removes insertions."""
+    return [(record.description, remove_insertions(str(record.seq))) for record in SeqIO.parse(filename, "fasta")]
+
+def read_msa_unaligned(filename: str):
+    """ Reads the sequences from an MSA file, removes only . - and * characters."""
+    return [(record.description, str(record.seq).replace(".","").replace("-","").replace("*","").upper()) for record in SeqIO.parse(filename, "fasta")]
+
+def check_msa(msa):
+    """ Checks if there are any repeated sequences in the MSA"""
+    seqs = set()
+    for el in msa:
+        seqs.add(el[1])
+    assert len(seqs) == len(msa), "There are repeated sequences in the MSA"
+
+def make_hmm_from_a3m_msa(msa_filepath, hmm_filename=None):
+    # Load MSA from a3m
+    msa_tup = read_msa(msa_filepath)
+    # check_msa(msa_tup)
+    # Create digitized MSA block
+    all_seqs = [pyhmmer.easel.TextSequence(name=str(i).encode("utf-8"), sequence=seq) for i, (idz, seq) in enumerate(msa_tup)]
+    msa = pyhmmer.easel.TextMSA(name=b"msa", sequences=all_seqs)
+    msa = msa.digitize(alphabet)
+    # Fit HMM
+    builder = pyhmmer.plan7.Builder(alphabet)
+    background = pyhmmer.plan7.Background(alphabet)
+    hmm, _, _ = builder.build_msa(msa, background)
+    if hmm_filename is not None:
+        with open(f"{hmm_filename}.hmm", "wb") as output_file:
+            hmm.write(output_file)
+    return hmm
+
+def align_and_score_sequences_in_a3m_with_hmm(hmm, sequences_path=None, sequences_list=None):
+    if sequences_list is not None:
+        msa = sequences_list
+        all_seqs = [pyhmmer.easel.TextSequence(name=str(i).encode("utf-8"), sequence=seq) for i, seq in enumerate(sequences_list)]
+    elif sequences_path is not None:
+        # Load sequences from a3m
+        msa = read_msa_unaligned(sequences_path)
+        all_seqs = [pyhmmer.easel.TextSequence(name=str(i).encode("utf-8"), sequence=seq) for i, (idz, seq) in enumerate(msa)]
+    else:
+        raise NotImplementedError("Missing sequences to align/score")
+    # Create digitized Sequence block
+    seq_block = pyhmmer.easel.TextSequenceBlock(all_seqs)
+    seq_block = seq_block.digitize(alphabet)
+    # Get all hits from the hmm
+    background = pyhmmer.plan7.Background(alphabet)
+    pipeline = pyhmmer.plan7.Pipeline(alphabet, background=background, bias_filter=False, F1=1.0, F2=1.0, F3=1.0)
+    hits = pipeline.search_hmm(hmm, seq_block)
+    if len(hits) != len(msa):
+        print(f"Number of hits: {len(hits)} is different from the number of sequences in the MSA: {len(msa)}")
+    # Extract hits
+    all_hits = {}
+    for hit in hits:
+        idz, score, evalue = hit.name, hit.score, hit.evalue
+        i = int(idz.decode("utf-8"))
+        seq = msa[i][1] if sequences_path is not None else sequences_list[i]
+        all_hits[seq] = {"score": score, "evalue": evalue}
+    return all_hits
+
+
+def score_hmmer(sequence_df, family_idx, data_dir = f"./data/"):
+
+    assert len(set(list(sequence_df["family"]))) == 1 and sequence_df["family"].iloc[0] == family_idx
+
+    family_id = sequence_df["family_id"].iloc[0]
+    msa_filepath = f"{data_dir}/a3m_files/{family_id}/a3m/uniclust30.a3m"
+    try:
+        hmm = make_hmm_from_a3m_msa(msa_filepath)
+    except:
+        raise Exception(f"Missing MSA of family {family_id}")
+
+    # align sequences
+    sequences = list(sequence_df["sequence"])
+    scores = align_and_score_sequences_in_a3m_with_hmm(hmm, sequences_list=sequences)
+
+    # save the scores associated to each sequence in the main df in the columns "score" and "evalue"
+    for seq in tqdm(sequences):
+        sequence_df.loc[sequence_df["sequence"] == seq, "score_gen"] = scores[seq]["score"] if seq in scores.keys() else 0
+        sequence_df.loc[sequence_df["sequence"] == seq, "evalue_gen"] = scores[seq]["evalue"] if seq in scores.keys() else 1
+
+    return sequence_df
+
+
+
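The two building blocks above can also be called directly. The hedged sketch below follows the same call pattern as `score_hmmer()`; the a3m path and the scored sequences are placeholders.

```python
# Hedged sketch: build a profile HMM from a family MSA and score arbitrary
# sequences against it. "family.a3m" and the example sequences are placeholders.
from protxlstm.applications.generation_utils.score_hmmer import (
    make_hmm_from_a3m_msa,
    align_and_score_sequences_in_a3m_with_hmm,
)

hmm = make_hmm_from_a3m_msa("family.a3m")
hits = align_and_score_sequences_in_a3m_with_hmm(
    hmm, sequences_list=["MKTAYIAKQR", "MRTAYLAKQR"]
)
for seq, hit in hits.items():
    # Higher bit score / lower e-value indicates a better fit to the family HMM.
    print(seq, hit["score"], hit["evalue"])
```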
protxlstm/applications/generation_utils/score_structure.py
ADDED
@@ -0,0 +1,55 @@
+from Bio.PDB import PDBParser
+import torch
+from tqdm import tqdm
+from transformers import EsmForProteinFolding
+
+from protxlstm.utils import MASK_TO_ID
+
+
+pdb_parser = PDBParser()
+
+
+def compute_structure(seq, model):
+    def keep_sequence(seq, l):
+        if len(seq) > l:
+            return False
+        for mm in list(MASK_TO_ID.keys())+["<eos>", "<pad>", "<unk>", "<mask>", "<cls>", "<null_1>", "." , "-"]:
+            if mm in seq:
+                return False
+        return True
+    keep = keep_sequence(seq, l=750)
+    if keep:
+        with torch.no_grad():
+            output = model.infer([seq])
+        # pdb = model.output_to_pdb(output)
+        ptm = output["ptm"].item()
+        pae = output["predicted_aligned_error"].cpu().numpy()
+        mean_plddt = ((output["plddt"] * output["atom37_atom_exists"]).sum(dim=(1, 2)) / output["atom37_atom_exists"].sum(dim=(1, 2))).item()
+        pos_plddt = ((output["plddt"] * output["atom37_atom_exists"]).sum(dim=(2,)) / output["atom37_atom_exists"].sum(dim=(2,))).cpu().numpy()
+    else:
+        print(f"Sequence is invalid.")
+        ptm, pae, mean_plddt, pos_plddt = 0, 0, 0, 0
+    return ptm, pae, mean_plddt, pos_plddt
+
+
+def score_structure(sequence_df, family_idx):
+
+    assert len(set(list(sequence_df["family"]))) == 1 and sequence_df["family"].iloc[0] == family_idx
+
+    device="cuda:0"
+
+    # Import the folding model
+    model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True)
+
+    model = model.cuda(device)
+    model.esm = model.esm.half()
+    torch.backends.cuda.matmul.allow_tf32 = True
+
+    sequences = list(sequence_df["sequence"])
+    for seq in tqdm(sequences):
+
+        ptm, pae, mean_plddt, pos_plddt = compute_structure(seq, model)
+        sequence_df.loc[sequence_df["sequence"] == seq, "ptm"] = ptm
+        sequence_df.loc[sequence_df["sequence"] == seq, "mean_plddt"] = mean_plddt
+
+    return sequence_df
protxlstm/applications/msa_sampler.py
ADDED
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Original code from ProtMamba under Apache License 2.0.
|
2 |
+
#
|
3 |
+
# Modifications made by Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
|
4 |
+
# - Modify handling of weights in `MSASampler`
|
5 |
+
|
6 |
+
import math
|
7 |
+
import os
|
8 |
+
from typing import Optional, Callable
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
|
13 |
+
from protxlstm.utils import AA_TO_ID
|
14 |
+
|
15 |
+
|
16 |
+
def compute_hamming_csim_torch(
|
17 |
+
seqs: torch.Tensor,
|
18 |
+
ungapped_msa: torch.Tensor,
|
19 |
+
gap_token: int,
|
20 |
+
gap_token_mask: int,
|
21 |
+
) -> torch.Tensor:
|
22 |
+
return (seqs.unsqueeze(1) == ungapped_msa).sum(dim=2)
|
23 |
+
|
24 |
+
def _compute_homology_weights(
|
25 |
+
ungapped_msa: np.ndarray,
|
26 |
+
gap_token: int,
|
27 |
+
gap_token_mask: int,
|
28 |
+
theta: float,
|
29 |
+
hamming_csim_func: Callable,
|
30 |
+
max_memory: int = 20,
|
31 |
+
can_use_torch: bool = True,
|
32 |
+
) -> np.ndarray:
|
33 |
+
use_torch = can_use_torch
|
34 |
+
if use_torch:
|
35 |
+
hamming_csim_func = compute_hamming_csim_torch
|
36 |
+
batch_size = math.floor(
|
37 |
+
2
|
38 |
+
* 1024
|
39 |
+
* 1024
|
40 |
+
* 1024
|
41 |
+
/ (ungapped_msa.shape[0] * ungapped_msa.shape[1])
|
42 |
+
* max_memory
|
43 |
+
/ 40
|
44 |
+
)
|
45 |
+
|
46 |
+
batch_size = 1 if batch_size == 0 else batch_size
|
47 |
+
|
48 |
+
neighbors = []
|
49 |
+
if not use_torch:
|
50 |
+
masked_ungapped_msa = ungapped_msa.copy()
|
51 |
+
else:
|
52 |
+
ungapped_msa = torch.from_numpy(ungapped_msa).byte()
|
53 |
+
masked_ungapped_msa = ungapped_msa.clone()
|
54 |
+
masked_ungapped_msa[masked_ungapped_msa == gap_token] = gap_token_mask
|
55 |
+
for b_start in range(0, len(ungapped_msa), batch_size):
|
56 |
+
b_end = b_start + batch_size
|
57 |
+
seqs = ungapped_msa[b_start:b_end]
|
58 |
+
|
59 |
+
sim = hamming_csim_func(
|
60 |
+
seqs=seqs,
|
61 |
+
ungapped_msa=masked_ungapped_msa,
|
62 |
+
gap_token=gap_token,
|
63 |
+
gap_token_mask=gap_token_mask,
|
64 |
+
)
|
65 |
+
if not use_torch:
|
66 |
+
sim = sim / (seqs != gap_token).sum(axis=1, keepdims=True)
|
67 |
+
d = 1 - sim
|
68 |
+
d = d.clamp(0, 1)
|
69 |
+
this_neighbors = (d <= theta).sum(axis=1)
|
70 |
+
else:
|
71 |
+
sim = sim / (seqs != gap_token).sum(dim=1, keepdim=True)
|
72 |
+
d = 1 - sim
|
73 |
+
# fillna
|
74 |
+
d[torch.isnan(d)] = 0
|
75 |
+
d = d.clamp(0, 1)
|
76 |
+
this_neighbors = (d <= theta).sum(dim=1).cpu()
|
77 |
+
neighbors.append(this_neighbors)
|
78 |
+
return np.concatenate(neighbors)
|
79 |
+
|
80 |
+
def compute_homology_weights(
|
81 |
+
ungapped_msa: np.ndarray,
|
82 |
+
theta: float = 0.2,
|
83 |
+
gap_token: int = AA_TO_ID["-"],
|
84 |
+
gap_token_mask: int = 255,
|
85 |
+
hamming_csim_func: Callable = compute_hamming_csim_torch,
|
86 |
+
) -> tuple[int, np.ndarray]:
|
87 |
+
"""
|
88 |
+
Calculate the effective number of sequences and sampling probability for the NEIGHBORS and NEIGHBORS_NO_LIMIT sampling methods using numpy.
|
89 |
+
|
90 |
+
Parameters:
|
91 |
+
|
92 |
+
ungapped_msa (np.ndarray): The MSA (from .fa).
|
93 |
+
theta (float, optional): A parameter used to determine the similarity between sequences. Default is 0.2.
|
94 |
+
gap_token (int, optional): The token representing gaps in the (Uniprot21 encoded) MSA. Default is 20.
|
95 |
+
gap_token_mask (int): token for masking gaps. should be a token not representing any other value.
|
96 |
+
|
97 |
+
Returns:
|
98 |
+
|
99 |
+
tuple[int, np.ndarray]: A tuple containing the effective number of sequences and the sampling probability for each sequence in the MSA.
|
100 |
+
"""
|
101 |
+
neighbors = _compute_homology_weights(
|
102 |
+
ungapped_msa=ungapped_msa,
|
103 |
+
gap_token=gap_token,
|
104 |
+
        gap_token_mask=gap_token_mask,
        theta=theta,
        hamming_csim_func=hamming_csim_func,
    )
    n_eff = np.sum(1 / neighbors)

    p = 1 / neighbors
    p /= np.sum(p)
    return n_eff, p


class MSASampler:

    def __init__(self, max_similarity, max_dissimilarity, force_include_first=True):
        self.max_similarity = max_similarity
        self.max_dissimilarity = max_dissimilarity
        self.force_include_first = force_include_first
        self.theta = 0.2

    def _get_sim_filtered_idxs(self, msa: np.ndarray) -> np.ndarray:
        nonnormalized_sim = (msa == msa[[0]]).sum(axis=1)
        normfactor = msa.shape[1]
        norm_sim = nonnormalized_sim / normfactor

        assert (norm_sim.min() >= 0) and (norm_sim.max() <= 1)
        dsim = 1 - norm_sim

        max_sim_filter = norm_sim <= self.max_similarity
        max_dissim_filter = dsim <= self.max_dissimilarity
        return np.where(max_sim_filter & max_dissim_filter)[0]

    def get_weights(
        self, msa: np.ndarray,
    ) -> tuple[Optional[float], Optional[np.ndarray]]:
        return compute_homology_weights(
            ungapped_msa=msa,
            theta=self.theta,
            gap_token_mask=255,
        )

    def get_sample_idxs(
        self,
        msa: np.ndarray,
        size: int = 1,
        random=False,
        msa_weights_path=None,
        seed=0,
    ) -> np.ndarray:

        np.random.seed(seed)

        if random:
            return np.random.choice(len(msa), replace=False, size=size) if len(msa) >= size else np.arange(len(msa))

        msa = np.array([[AA_TO_ID[aa] for aa in seq.upper()][:len(msa[0])] for seq in msa], dtype=np.uint8)

        if msa_weights_path and os.path.exists(msa_weights_path):
            weights = np.load(msa_weights_path)
        elif msa_weights_path:
            os.makedirs(os.path.dirname(msa_weights_path), exist_ok=True)
            _, weights = self.get_weights(
                msa=msa,
            )
            np.save(msa_weights_path, weights)
        else:
            _, weights = self.get_weights(
                msa=msa,
            )

        original_msa_sample_idxs = np.arange(len(msa))
        sample_idxs = self._get_sim_filtered_idxs(msa)
        original_msa_sample_idxs = original_msa_sample_idxs[sample_idxs]

        if self.force_include_first:
            original_msa_sample_idxs = np.concatenate(
                [[0], original_msa_sample_idxs[original_msa_sample_idxs != 0]]
            )
        return np.random.choice(len(msa), replace=False, size=size, p=weights / weights.sum()) if len(msa) >= size else original_msa_sample_idxs


def sample_msa(msa_sequences, msa_weights_path=None, context_length=200_000, max_context_sequences=200, seed=0, sort=True):
    """Sample MSA sequences for the context"""
    n_sequences = min(context_length // len(msa_sequences[0]), len(msa_sequences) if max_context_sequences == 0 else max_context_sequences) - 1
    sampler = MSASampler(0.98, 0.7, force_include_first=False)
    sample_idx = sampler.get_sample_idxs(
        msa_sequences, size=n_sequences, msa_weights_path=msa_weights_path, seed=seed
    )

    # Sort sequences from least similar to most similar and add wild type target sequence
    if sort:
        context_sequences = [msa_sequences[i] for i in sample_idx][::-1]

    return context_sequences
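A minimal usage sketch of this sampler, with toy sequences and placeholder paths (not part of the repository); it only illustrates how `sample_msa` is meant to be called to build a model context.

```python
# Hypothetical usage sketch of sample_msa (toy data, illustrative only).
from protxlstm.applications.msa_sampler import sample_msa

msa_sequences = [
    "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",   # query / target sequence first
    "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVA",
    "MKSAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
]

# Sample a context that fits into ~1000 tokens with at most 2 context sequences.
# msa_weights_path (optional) would cache the homology weights on disk.
context = sample_msa(
    msa_sequences,
    msa_weights_path=None,
    context_length=1000,
    max_context_sequences=2,
    seed=0,
)
print(context)
```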
protxlstm/applications/sample_sequences.py
ADDED
@@ -0,0 +1,200 @@
import torch
from tqdm import tqdm
import pickle
import os
import argparse
import json

from protxlstm.dataloaders import ProteinMemmapDataset
from protxlstm.generation import generate_sequence
from protxlstm.utils import (
    AA_TO_ID,
    load_model,
)
from protxlstm.models.xlstm import xLSTMLMHeadModel
from protxlstm.models.mamba import MambaLMHeadModelwithPosids


def sample_sequences(dataset,
                     model,
                     family_idx,
                     params,
                     n_samples_per_family,
                     max_length=1000,
                     chunk_chunk_size=2**15,
                     save_path=None,
                     device="cuda:0"):
    """
    Function to sample sequences from the model. Given a dataset, a list of families (their indexes in the dataset)
    and a set of generating parameters, it generates `n_samples_per_family` sequences for each family and each parameter set.
    The function returns a dictionary with the following structure:
        gen_seqs = {family_idx: {parameters: {sequence: perplexity}}}
    The parameters are in a list of tuples with the following structure:
        parameters_list = [(nr_seqs_ctx, temperature, top_k, top_p)]
    """
    gen_seqs = {}
    gen_seqs[family_idx] = {}
    gen_seqs[family_idx][params] = {}
    print(f"Sampling sequences for family {family_idx} and parameters {params}.")

    n_seqs_ctx, temperature, top_k, top_p = params
    for _ in tqdm(range(n_samples_per_family)):
        # Sample the dataset to get the input
        data = dataset[family_idx]
        tokens = data["input_ids"][None, :].to(device)
        pos_ids = data["position_ids"][None, :].to(device)

        start_seqs = torch.argwhere(tokens[0] == 0)[:, 0].cpu().numpy()

        n_seqs_ctx = len(start_seqs) if len(start_seqs) < n_seqs_ctx else n_seqs_ctx
        L = start_seqs[n_seqs_ctx] + 1
        context_tokens = tokens[:, :L]
        context_pos_ids = pos_ids[:, :L]
        is_fim = {}

        # Generate the new sequence
        output = generate_sequence(model,
                                   context_tokens,
                                   position_ids=context_pos_ids,
                                   is_fim=is_fim,
                                   max_length=(L + max_length),
                                   temperature=temperature,
                                   top_k=top_k,
                                   top_p=top_p,
                                   return_dict_in_generate=True,
                                   output_scores=True,
                                   eos_token_id=torch.tensor([AA_TO_ID["<cls>"]]).to(device),
                                   chunk_chunk_size=chunk_chunk_size,
                                   device=device)

        # Get the perplexity of the generated sequence
        output_seq = output["generated"]
        loss = torch.nn.functional.cross_entropy(torch.from_numpy(output["scores"]).permute(0, 2, 1),
                                                 torch.from_numpy(output["generated_tokens"][0][None, :]))

        # save only sequences with length < max_length
        if len(output_seq[0]) < max_length:

            gen_seqs[family_idx][params][output_seq[0]] = {"perplexity": torch.exp(loss).item()}

    if save_path is not None:
        if not os.path.exists("evaluation/generation/generated_sequences"):
            os.mkdir("evaluation/generation/generated_sequences")
        if not os.path.exists(save_path):
            os.mkdir(save_path)
        with open(f'{save_path}/{family_idx}_{params}_{n_samples_per_family}', "wb") as f:
            pickle.dump(gen_seqs, f)
        print(f"Sequences saved for family {family_idx} and parameters {params}")

    return gen_seqs


def generate_sequences(model_name,
                       checkpoint,
                       family_idxs=[],
                       parameters_list=[],
                       n_samples_per_family=100,
                       chunk_size=1024,
                       chunk_chunk_size=2**15,
                       data_dir="data/",
                       device="cuda:0"
                       ):

    # Load the test dataset
    fim_strategy = "multiple_span"
    mask_fraction = 0.2

    dataset = ProteinMemmapDataset(
        msa_memmap_path=f"{data_dir}open_protein_set_memmap.dat",
        msa_memmap_meta_path=f"{data_dir}open_protein_set_memmap_indices.csv",
        subset_path=f"{data_dir}cluster_testing_set.txt",
        sample=False,
        max_msa_len=-1,
        reverse=False,
        seed=0,
        troubleshoot=False,
        fim_strategy=fim_strategy,
        always_mask=False,
        max_position_embeddings=2048,
        max_seq_position_embeddings=512,
        add_position_ids="1d",
        mask_fraction=mask_fraction
    )

    if model_name == "xlstm":
        model_class = xLSTMLMHeadModel
    elif model_name == "mamba":
        model_class = MambaLMHeadModelwithPosids

    save_path = f"evaluation/generation/generated_sequences/{checkpoint.split('/')[-1]}"

    if model_name == "xlstm":
        config_update_kwargs = {
            "mlstm_backend": "chunkwise_variable",
            "mlstm_chunksize": chunk_size,
            "mlstm_return_last_state": True
        }
    else:
        config_update_kwargs = {}

    # load the model
    model = load_model(checkpoint,
                       model_class=model_class,
                       device=device,
                       dtype=torch.bfloat16,
                       **config_update_kwargs,
                       )
    model = model.eval()
    print("Model loaded.")

    for family_idx in family_idxs:
        for params in parameters_list:
            params = tuple(params)
            if not os.path.exists(f'{save_path}/{family_idx}_{params}_{n_samples_per_family}'):
                gen_seqs = sample_sequences(
                    dataset=dataset,
                    model=model,
                    family_idx=family_idx,
                    params=params,
                    n_samples_per_family=n_samples_per_family,
                    chunk_chunk_size=chunk_chunk_size,
                    save_path=save_path,
                    device=device)

                print(f"Sampled {len(gen_seqs[family_idx][params])} valid sequences.")
            else:
                print(f"Sequences for family {family_idx} and parameters {params} already exist.")


if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description="Generate sequences."
    )
    parser.add_argument("--model_name", type=str, help="Either 'xlstm' or 'mamba'.")
    parser.add_argument("--checkpoint", type=str, help="Path to model checkpoint.")
    parser.add_argument("--family_idxs", type=str, help="List of family indices.")
    parser.add_argument("--parameters_list", type=str, help="List of sampling parameters.")
    parser.add_argument("--n_samples_per_family", type=int, default=100, help="Number of sequences to sample per family and parameter set.")
    parser.add_argument("--chunk_size", type=int, default=1024, help="Chunk size for xLSTM context encoding.")
    parser.add_argument("--chunk_chunk_size", type=int, default=2**15, help="Length of context sequence part processed at once.")
    parser.add_argument("--data_dir", type=str, default="data/", help="Path to dataset.")
    parser.add_argument("--device", type=str, default="cuda:0", help="Device.")

    args = parser.parse_args()

    family_idxs = json.loads(args.family_idxs)
    parameters_list = json.loads(args.parameters_list)

    # Run sequence generation
    generate_sequences(
        model_name=args.model_name,
        checkpoint=args.checkpoint,
        family_idxs=family_idxs,
        parameters_list=parameters_list,
        n_samples_per_family=args.n_samples_per_family,
        chunk_size=args.chunk_size,
        chunk_chunk_size=args.chunk_chunk_size,
        data_dir=args.data_dir,
        device=args.device,
    )
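For illustration, the same entry point can also be called directly from Python instead of the CLI; the family indices and sampling tuples below are placeholders, and the checkpoint path points at the small checkpoint shipped in this Space.

```python
# Illustrative direct call to the generation entry point (placeholder values).
from protxlstm.applications.sample_sequences import generate_sequences

generate_sequences(
    model_name="xlstm",
    checkpoint="protxlstm/checkpoints/small",
    family_idxs=[0, 1],                               # placeholder family indices
    parameters_list=[(10, 1.0, 10, 1.0),              # (nr_seqs_ctx, temperature, top_k, top_p)
                     (100, 0.9, 15, 1.0)],
    n_samples_per_family=10,
    data_dir="data/",
    device="cuda:0",
)
```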
protxlstm/applications/score_sequences.py
ADDED
@@ -0,0 +1,58 @@
import argparse
import os
import pickle

from generation_utils.create_sequence_df import create_sequence_df
from generation_utils.score_hamming import score_hamming
from generation_utils.score_hmmer import score_hmmer
from generation_utils.score_structure import score_structure


def score_sequences(model_name,
                    family_idx,
                    num_sequences=100,
                    data_dir="data/"):

    if os.path.isfile(f"evaluation/generation/evaluations/{model_name}/sequence_df_{family_idx}"):
        with open(f"evaluation/generation/evaluations/{model_name}/sequence_df_{family_idx}", "rb") as f:
            sequence_df = pickle.load(f)
    else:
        sequence_df = create_sequence_df(model_name, family_idx, data_dir=data_dir, num_sequences=num_sequences)
        if not os.path.exists("evaluation/generation/evaluations/"):
            os.mkdir("evaluation/generation/evaluations/")
        if not os.path.exists(f"evaluation/generation/evaluations/{model_name}/"):
            os.mkdir(f"evaluation/generation/evaluations/{model_name}/")
        with open(f"evaluation/generation/evaluations/{model_name}/sequence_df_{family_idx}", "wb") as f:
            pickle.dump(sequence_df, f)

    if not "min_hamming" in sequence_df.columns:
        sequence_df = score_hamming(sequence_df, family_idx, data_dir)
        with open(f"evaluation/generation/evaluations/{model_name}/sequence_df_{family_idx}", "wb") as f:
            pickle.dump(sequence_df, f)

    if not "score_gen" in sequence_df.columns:
        sequence_df = score_hmmer(sequence_df, family_idx, data_dir)
        with open(f"evaluation/generation/evaluations/{model_name}/sequence_df_{family_idx}", "wb") as f:
            pickle.dump(sequence_df, f)

    if not "ptm" in sequence_df.columns:
        sequence_df = score_structure(sequence_df, family_idx)
        with open(f"evaluation/generation/evaluations/{model_name}/sequence_df_{family_idx}", "wb") as f:
            pickle.dump(sequence_df, f)

    return sequence_df


if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description="Score generated sequences."
    )
    parser.add_argument("--model_name", type=str, help="Either 'xlstm' or 'mamba'.")
    parser.add_argument("--family_idx", type=int, help="Family index.")
    parser.add_argument("--num_sequences", type=int, default=100, help="Number of sequences.")
    parser.add_argument("--data_dir", type=str, default="./data/", help="Path to dataset.")

    args = parser.parse_args()

    sequence_df = score_sequences(args.model_name, args.family_idx, args.num_sequences, args.data_dir)
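A small sketch of calling the scorer from Python; the family index is a placeholder and the call assumes generated sequences for that family already exist under `evaluation/generation/`. The scores are written incrementally, so re-running only computes the missing columns.

```python
# Illustrative direct call (placeholder family index).
from protxlstm.applications.score_sequences import score_sequences

sequence_df = score_sequences(
    model_name="xlstm",       # must match the checkpoint folder name used when sampling
    family_idx=0,
    num_sequences=100,
    data_dir="data/",
)
# Columns added by the three scoring passes: min_hamming, score_gen, ptm
print(sequence_df[["min_hamming", "score_gen", "ptm"]].describe())
```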
protxlstm/checkpoints/small/config.json
ADDED
@@ -0,0 +1 @@
{"mlstm_block": {"mlstm": {"proj_factor": 2.0, "round_proj_up_dim_up": true, "round_proj_up_to_multiple_of": 64, "_proj_up_dim": 1024, "conv1d_kernel_size": 4, "qkv_proj_blocksize": 4, "num_heads": 4, "embedding_dim": 512, "bias": false, "dropout": 0.0, "context_length": 2048, "backend": "chunkwise", "chunk_size": 1024, "return_last_state": false, "_num_blocks": 16, "_inner_embedding_dim": 1024}}, "slstm_block": {"slstm": {"hidden_size": 512, "num_heads": 4, "num_states": 4, "backend": "cuda", "function": "slstm", "bias_init": "powerlaw_blockdependent", "recurrent_weight_init": "zeros", "_block_idx": 0, "_num_blocks": 16, "num_gates": 4, "gradient_recurrent_cut": false, "gradient_recurrent_clipval": null, "forward_clipval": null, "batch_size": 8, "input_shape": "BSGNH", "internal_input_shape": "SBNGH", "output_shape": "BNSH", "constants": {}, "dtype": "bfloat16", "dtype_b": "float32", "dtype_r": "bfloat16", "dtype_w": "bfloat16", "dtype_g": "bfloat16", "dtype_s": "bfloat16", "dtype_a": "float32", "enable_automatic_mixed_precision": true, "initial_val": 0.0, "embedding_dim": 512, "conv1d_kernel_size": 4, "dropout": 0.0}, "feedforward": {"proj_factor": 1.3, "round_proj_up_dim_up": true, "round_proj_up_to_multiple_of": 64, "_proj_up_dim": 0, "act_fn": "gelu", "embedding_dim": -1, "dropout": 0.0, "bias": false, "ff_type": "ffn_gated", "_num_blocks": 1}, "_num_blocks": 16, "_block_idx": 0}, "context_length": 2048, "num_blocks": 16, "embedding_dim": 512, "add_post_blocks_norm": true, "bias": false, "dropout": 0.0, "checkpoint_blocks": true, "slstm_at": [], "_block_map": "0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0", "vocab_size": 38, "tie_weights": false, "weight_decay_on_embedding": false, "add_embedding_dropout": false, "position_embeddings": "rot_1d", "max_position_embeddings": 2048, "max_seq_position_embeddings": 512, "rope_base_frequency": 500000}
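A quick sketch of inspecting this checkpoint config from Python; the path is the one added in this commit, and the printed values come directly from the JSON above.

```python
import json

# Read the small-checkpoint config shipped with the Space.
with open("protxlstm/checkpoints/small/config.json") as f:
    cfg = json.load(f)

print(cfg["num_blocks"], cfg["embedding_dim"], cfg["vocab_size"])   # 16 512 38
print(cfg["mlstm_block"]["mlstm"]["backend"])                        # "chunkwise"
```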
protxlstm/checkpoints/small/optimizer.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6bcae4c2a893afed859ca9b5926d24e8f7d0b22eca198e4aa950b80909be8e50
size 207533690
protxlstm/checkpoints/small/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9ba59f1fd544a5d9f6c4adb40730cf90b8f69a772df838246f724586cb1d602a
size 103773526
protxlstm/checkpoints/small/rng_state.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5480500274058efdf9caa959d35a42a948ff3dc8536e082b9bc22f2ecd423108
size 14244
protxlstm/checkpoints/small/scheduler.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:25a26dbf75285e97210697d5608b6d76ef35aa0d2879be319ef2785f881153b9
size 1000
protxlstm/checkpoints/small/trainer_state.json
ADDED
The diff for this file is too large to render.
See raw diff
protxlstm/data.py
ADDED
@@ -0,0 +1,60 @@
import csv
import os

import numpy as np
from tqdm import tqdm

from protxlstm.utils import load_sequences_from_msa_file, tokenizer

def process_msa(msa_item):
    msa_name, msa_path = msa_item
    # Load an a3m file with all the context sequences
    msa = load_sequences_from_msa_file(msa_path)
    # Tokenize the sequences and concatenate them into a single array
    tokens = tokenizer(msa, concatenate=True)
    tokens = tokens.numpy()[0]
    return msa_name, tokens

def main(data_dir, output_dir):
    msa_paths = {k: os.path.join(data_dir, k, 'a3m/uniclust30.a3m') for k in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, k))}
    msa_items = list(msa_paths.items())

    dataset_dictionary = {}
    total_length = 0

    # First pass: calculate total length of all concatenated arrays
    for item in tqdm(msa_items):
        try:
            k, v = process_msa(item)
            dataset_dictionary[k] = v
            total_length += len(v)
        except:
            print(f"Error processing {item}")

    # Initialize the memmap array with the calculated total length
    memmap_path = os.path.join(output_dir, 'open_protein_set_memmap.dat')
    concatenated_array = np.memmap(memmap_path, dtype='int8', mode='w+', shape=(total_length,))

    with open(f'{output_dir}/open_protein_set_memmap_indices.csv', 'w', newline='') as csvfile:
        csvwriter = csv.writer(csvfile)

        csvwriter.writerow(['msa_id', 'Start', 'End'])

        start_index = 0
        for key, array in dataset_dictionary.items():
            end_index = start_index + len(array) - 1
            concatenated_array[start_index:end_index + 1] = array  # Write to memmap
            csvwriter.writerow([key, start_index, end_index])
            start_index = end_index + 1

    # Ensure the data is written to disk
    concatenated_array.flush()


if __name__ == "__main__":
    data_dir = 'data/a3m_files'
    output_dir = 'data/'
    main(data_dir, output_dir)
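A minimal sketch of reading one cluster back out of the memmap written by `main()`; the paths match the outputs above, and `End` is stored inclusively, so the slice uses `End + 1`.

```python
# Read one MSA cluster back from the memmap (paths as produced by main()).
import numpy as np
import pandas as pd

indices = pd.read_csv("data/open_protein_set_memmap_indices.csv")
data = np.memmap("data/open_protein_set_memmap.dat", dtype="int8", mode="r")

row = indices.iloc[0]                    # or indices[indices.msa_id == "<some msa_id>"]
tokens = data[row.Start : row.End + 1]   # token ids of the whole cluster
print(row.msa_id, len(tokens))
```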
protxlstm/dataloaders.py
ADDED
@@ -0,0 +1,249 @@
# Original code from ProtMamba under Apache License 2.0.
#
# Modifications made by Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
# - Uniclust30_Dataset renamed to ProteinMemmapDataset
# - Dataset input file format changed for more efficient dataloading
# - Option to use only a subset
# - DataCollatorForUniclust30Dataset renamed to ProteinDataCollator
# - Add sequence padding

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from typing import Dict, Optional, Sequence

from protxlstm.fim import MultipleSpanFIM, NoFIM, SingleSpanFIM
from protxlstm.utils import AA_TO_ID


# Make dataset
class ProteinMemmapDataset(Dataset):
    """
    ProteinMemmapDataset is a PyTorch Dataset class for handling memory-mapped datasets of protein multiple sequence alignments (MSAs).

    This class imports MSA data stored in memmap format and associated metadata CSVs. It supports flexible
    data sampling strategies and inpainting methods for sequence manipulation and training purposes.

    Args:
        msa_memmap_path (str): Path to the memory-mapped file containing the MSA clusters.
        msa_memmap_meta_path (str): Path to the CSV file with metadata linking MSA Cluster IDs and indices in the memmap array.
        subset_path (str, optional): Path to a CSV file specifying a subset of cluster IDs to use.
        sample (bool, optional): If True, randomly samples sequences from each cluster; otherwise, loads all sequences and shuffles them.
        max_msa_len (int, optional): Maximum length of the MSA sequences to include. Defaults to -1 (no limit).
        reverse (bool, optional): If True, reverses sequences with a probability of 0.5 and moves the last token to the front.
        seed (int, optional): Random seed for reproducibility. Defaults to 42.
        troubleshoot (bool, optional): If True, prints debugging information. Defaults to False.
        fim_strategy (str, optional): Strategy for inpainting ("no-scramble", "one_span", or "multiple_span").
        max_patches (int, optional): Number of patches for inpainting. Used when fim_strategy is "multiple_span".
        mask_fraction (float, optional): Fraction of the patches to mask. Used when fim_strategy is "multiple_span".
        always_mask (bool, optional): If True, ensures masking is applied in the inpainting process.
        max_position_embeddings (int, optional): Maximum position embeddings. Defaults to 2048.
        max_seq_position_embeddings (int, optional): Maximum sequence position embeddings for 2D positional IDs. Defaults to 512.
        add_position_ids (str, optional): Type of position IDs to add ("none", "1d", or "2d"). Defaults to "1d".
    """

    _FIM = {"no-scramble": NoFIM, "one_span": SingleSpanFIM, "multiple_span": MultipleSpanFIM}
    _POSIDS = {"none", "1d", "2d"}

    def __init__(self,
                 msa_memmap_path=None,
                 msa_memmap_meta_path=None,
                 subset_path=None,
                 sample=False,
                 max_msa_len=-1,
                 reverse=False,
                 seed=42,
                 troubleshoot=False,
                 fim_strategy="no-scramble",
                 max_patches=5,
                 mask_fraction=0.2,
                 always_mask=False,
                 max_position_embeddings=2048,
                 max_seq_position_embeddings=512,
                 add_position_ids="1d", ):

        np.random.seed(seed)

        if msa_memmap_path:
            self.dataset = np.memmap(msa_memmap_path, dtype=np.int8, mode='r')
            self.dataset_meta = pd.read_csv(msa_memmap_meta_path)
            if subset_path:
                subset_ids = pd.read_csv(subset_path, header=None, names=['ID'])['ID'].tolist()
                self.dataset_meta = self.dataset_meta[self.dataset_meta['msa_id'].isin(subset_ids)]
        else:
            self.dataset = None

        self.sample = sample
        self.max_msa_len = max_msa_len
        self.reverse = reverse
        self.fim_strategy = fim_strategy
        if fim_strategy in ProteinMemmapDataset._FIM:
            self.fim = ProteinMemmapDataset._FIM[fim_strategy](max_patches=max_patches,
                                                               mask_fraction=mask_fraction,
                                                               always_mask=always_mask,
                                                               add_position_ids=add_position_ids != "none",
                                                               troubleshoot=troubleshoot)
        else:
            raise ValueError(f'Fill in the middle strategy "{fim_strategy}" not recognized.')

        self.max_position_embeddings = max_position_embeddings
        self.max_seq_position_embeddings = max_seq_position_embeddings
        self.add_position_ids = add_position_ids

        self.troubleshoot = troubleshoot

    def __len__(self):
        # meta dataframe has one row for each MSA cluster
        return len(self.dataset_meta)

    def __getitem__(self, idx):
        # get all the sequences in the cluster
        sequences = self.get_sequences(idx)
        # get total number of sequences in the cluster and choose how many to sample
        orig_num_sequences = len(self.get_index_start_of_sequences(sequences))
        num_sequences = np.random.randint(1, orig_num_sequences + 1) if self.sample else orig_num_sequences
        # sample the sequences
        sequences, position_ids = self.sample_sequences(sequences, num_sequences)
        # with probability 0.5, reverse the sequences and move the last token to the front
        sequences, position_ids = self.reverse_sequences(sequences, position_ids) if (
                self.reverse and np.random.rand() > 0.5) else sequences, position_ids
        # limit the length of the MSA
        sequences = sequences[:self.max_msa_len] if self.max_msa_len > 0 else sequences
        if self.add_position_ids != "none":
            position_ids = position_ids[:self.max_msa_len] if self.max_msa_len > 0 else position_ids
        # convert to tensor
        sequences = torch.asarray(sequences, dtype=torch.int64)
        position_ids = torch.asarray(position_ids, dtype=torch.int64).clamp(0,
                                                                            self.max_position_embeddings - 1) if self.add_position_ids != "none" else None

        if self.troubleshoot:
            print(
                f"Cluster {idx} has {orig_num_sequences} sequences, of which {num_sequences} sampled now. Total MSA length: {len(sequences)}")
        if self.add_position_ids == "1d":
            return dict(input_ids=sequences, position_ids=position_ids, labels=sequences)
        if self.add_position_ids == "2d":
            seq_position_ids = (sequences == AA_TO_ID["<cls>"]).int().cumsum(-1).clamp(0,
                                                                                       self.max_seq_position_embeddings - 1).contiguous()
            return dict(input_ids=sequences, position_ids=position_ids, seq_position_ids=seq_position_ids,
                        labels=sequences)
        return dict(input_ids=sequences, labels=sequences)

    def get_msa_id(self, idx):
        """Get the MSA ID in the cluster with index `idx`."""
        cluster_meta = self.dataset_meta.iloc[idx]
        return cluster_meta.msa_id

    def get_idx_from_msa_id(self, msa_id):
        """Get `idx` with the MSA ID"""
        return self.dataset_meta[self.dataset_meta.msa_id == msa_id].index[0]

    def get_sequences(self, idx):
        """Get the sequences in the cluster with index `idx`."""
        cluster_meta = self.dataset_meta.iloc[idx]
        sequences = self.dataset[cluster_meta.Start : cluster_meta.End]
        return sequences

    def get_index_start_of_sequences(self, sequences):
        """Get the positions of the start of each sequence in the cluster."""
        return np.where(sequences == 0)[0]

    def reverse_sequences(self, sequence, position_ids=None):
        """Reverse the sequences and move the last token to the front."""
        sequence = sequence[::-1]
        if position_ids is not None:
            position_ids = position_ids[::-1]
        return np.concatenate([sequence[-1:], sequence[:-1]]), np.concatenate(
            [position_ids[-1:], position_ids[:-1]]) if position_ids is not None else None

    def sample_sequences(self, sequences, num_sequences, shuffle=True):
        """Sample `num_sequences` from the sequences in the cluster."""
        L = len(sequences)
        # get the indexes of the start of each sequence
        inds = self.get_index_start_of_sequences(sequences)
        # check that there are sequences in the cluster and that there are enough of them
        assert len(inds) > 0, "No sequences found in cluster."
        assert len(inds) >= num_sequences, "Not enough sequences in cluster."
        # sample n_sequences randomly from the sequences
        if shuffle:
            which_seqs = np.random.choice(np.arange(len(inds)), num_sequences, replace=False)
        else:
            which_seqs = np.arange(len(inds))[-num_sequences:]
        # get the tuples of start and end indexes of the sequences
        tuples = [(inds[i], inds[i + 1]) if i < len(inds) - 1 else (inds[i], L) for i in which_seqs]
        if self.troubleshoot:
            print(f"Sampled sequences: {tuples}")
        # concatenate the sequences
        sequences, position_ids = self.fim.apply(sequences, tuples)
        return sequences, position_ids


def make_dataloader(dataset):
    """Basic function to make a dataloader.
    """
    dataloader = DataLoader(dataset)
    return dataloader


class ProteinDataCollator(object):
    """
    Collate examples into a batch, and pad batch to a specified maximum sequence length,
    or to the longest sequence in the batch if max_sequence_length is None.
    """
    def __init__(self, max_sequence_length: Optional[int] = None):
        """
        Initialize the collator with an optional max_sequence_length.

        Args:
            max_sequence_length (Optional[int]): The maximum sequence length to pad/truncate to.
                If None, pad to the longest sequence in the batch.
        """
        self.max_sequence_length = max_sequence_length

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:

        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "input_ids"))

        longest_seq = max(len(seq) for seq in input_ids)
        if self.max_sequence_length is None:
            max_len = longest_seq
        else:
            max_len = self.max_sequence_length

        input_ids = self.pad_sequences(input_ids, max_len, padding_value=AA_TO_ID["<pad>"])

        labels = self.pad_sequences(labels, longest_seq, padding_value=AA_TO_ID["<pad>"])
        labels = self.pad_sequences(labels, max_len, padding_value=-100)

        return_dict = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(AA_TO_ID["<pad>"])
        )

        if "position_ids" in instances[0]:

            position_ids = [instance["position_ids"] for instance in instances]
            position_ids = self.pad_sequences(position_ids, max_len, padding_value=0)
            return_dict["position_ids"] = position_ids

        if "seq_position_ids" in instances[0]:
            seq_position_ids = [instance["seq_position_ids"] for instance in instances]
            seq_position_ids = self.pad_sequences(seq_position_ids, max_len, padding_value=0)
            return_dict["seq_position_ids"] = seq_position_ids

        return return_dict

    def pad_sequences(self, seqs, max_length, padding_value):
        # truncate long sequences (redundant, already done in __getitem__, maybe safe to remove)
        seqs = [seq[:max_length] for seq in seqs]

        # pad to same length
        seqs = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=padding_value)

        # pad to max length
        padding = max_length - seqs.size(1)
        seqs = torch.nn.functional.pad(seqs, (0, padding), value=padding_value)

        return seqs
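A minimal sketch of wiring the dataset and collator into a training-style DataLoader; the memmap paths are assumed to exist (see data.py above) and the batch size is arbitrary.

```python
# Sketch: dataset + collator + DataLoader (assumed paths, illustrative values).
from torch.utils.data import DataLoader
from protxlstm.dataloaders import ProteinMemmapDataset, ProteinDataCollator

dataset = ProteinMemmapDataset(
    msa_memmap_path="data/open_protein_set_memmap.dat",
    msa_memmap_meta_path="data/open_protein_set_memmap_indices.csv",
    sample=True,
    max_msa_len=2048,
    fim_strategy="multiple_span",
    add_position_ids="1d",
)
collator = ProteinDataCollator(max_sequence_length=2048)
loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collator)

batch = next(iter(loader))
print(batch["input_ids"].shape, batch["position_ids"].shape, batch["attention_mask"].shape)
```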
protxlstm/fim.py
ADDED
@@ -0,0 +1,203 @@
# Original code from ProtMamba under Apache License 2.0.

from protxlstm.utils import MASK_TO_ID, AA_TO_ID
import numpy as np

class AbstractFIM(object):
    def __init__(self,
                 max_patches=5,
                 mask_fraction=0.2,
                 always_mask=False,
                 mask_tokens=MASK_TO_ID,
                 eos_token=AA_TO_ID["<eos>"],
                 add_position_ids=False,
                 troubleshoot=False):
        """
        This class is designed to concatenate sequences based on different scrambling strategies.
        It takes a list of sequences, tuples indicating the start and end indices of each sequence,
        an optional number of patches to sample, and a scrambling strategy as inputs.
        """
        self.troubleshoot = troubleshoot
        self.max_patches = max_patches
        self.mask_fraction = mask_fraction
        self.mask_tokens = mask_tokens
        assert len(
            self.mask_tokens) >= self.max_patches, "Number of mask tokens must be bigger than max number of patches."
        self.eos_token = eos_token
        self.add_position_ids = add_position_ids
        self.always_mask = always_mask

    def apply(self, sequences, tuples):
        """
        This function concatenates the sequences scrambling each one according to the scrambling strategy.
        """
        input_ids, position_ids = [], []
        for t in tuples:
            seq, pos = self.fim(sequences, t)
            input_ids.extend(seq)
            if self.add_position_ids:
                position_ids.extend(pos)
        if self.add_position_ids:
            return input_ids, position_ids
        return input_ids, None

    def fim(self, sequences, t):
        """
        This function concatenates the sequence's parts based on the scrambling strategy.
        """
        raise NotImplementedError


class NoFIM(AbstractFIM):
    def __init__(self,
                 max_patches=5,
                 mask_fraction=0.2,
                 always_mask=False,
                 mask_tokens=MASK_TO_ID,
                 eos_token=AA_TO_ID["<eos>"],
                 add_position_ids=False,
                 troubleshoot=False):
        super().__init__(max_patches, mask_fraction, always_mask, mask_tokens, eos_token, add_position_ids, troubleshoot)

    def fim(self, sequences, t):
        """
        This function keeps the sequence identical without any scrambling.
        """
        if self.add_position_ids:
            position_ids = np.arange(t[0], t[1]) - t[0]
            return sequences[t[0]:t[1]], position_ids
        return sequences[t[0]:t[1]], None


class SingleSpanFIM(AbstractFIM):

    def __init__(self,
                 max_patches=5,
                 mask_fraction=0.2,
                 always_mask=False,
                 mask_tokens=MASK_TO_ID,
                 eos_token=AA_TO_ID["<eos>"],
                 add_position_ids=False,
                 troubleshoot=False):
        super().__init__(max_patches, mask_fraction, always_mask, mask_tokens, eos_token, add_position_ids, troubleshoot)

    def fim(self, sequences, t):
        """
        This function creates and concatenates parts of the sequences based on the OpenAI scrambling strategy.
        It randomly selects two indices within the range of the given tuple,
        splits the sequence into three parts based on these indices, and then concatenates them with the
        masked patch at the end.
        """
        new_tuple = tuple(np.sort(np.random.choice(np.arange(t[0] + 1, t[1]), 2, replace=False)))
        part1 = sequences[t[0]:new_tuple[0]]
        part2 = sequences[new_tuple[0]:new_tuple[1]]
        part3 = sequences[new_tuple[1]:t[1]]
        sequence = np.concatenate([part1, [self.mask_tokens["<mask-1>"]], part3, [self.mask_tokens["<mask-1>"]], part2])
        position_ids_sequence = None
        if self.add_position_ids:
            position_ids = np.arange(t[0], t[1]) - t[0]
            position_ids_part1 = position_ids[t[0]:new_tuple[0]]
            position_ids_part2 = position_ids[new_tuple[0]:new_tuple[1]]
            position_ids_part3 = position_ids[new_tuple[1]:t[1]]
            position_ids_sequence = np.concatenate(
                [position_ids_part1, [position_ids_part2[0]], position_ids_part3, [position_ids_part2[0]],
                 position_ids_part2])

        return sequence, position_ids_sequence


class MultipleSpanFIM(AbstractFIM):
    def __init__(self,
                 max_patches=5,
                 mask_fraction=0.2,
                 always_mask=False,
                 mask_tokens=MASK_TO_ID,
                 eos_token=AA_TO_ID["<eos>"],
                 add_position_ids=False,
                 troubleshoot=False):
        super().__init__(max_patches, mask_fraction, always_mask, mask_tokens, eos_token, add_position_ids, troubleshoot)

    def fim(self, sequences, t):
        """
        This function creates and concatenates parts of the sequences based on the inpaint scrambling strategy.
        It randomly selects `2*num_patches` indices within the range of the given tuple,
        splits the sequence into unmasked and masked parts based on these indices, and then concatenates them.
        The number of patches is sampled from a poisson distribution with upper limit `self.max_patches` and average 1.
        The concatenation is done by joining all unmasked parts (interleaved with mask tokens) and afterwards
        all masked parts (interleaved with mask tokens). At the end of the unmasked parts, a special token is added
        to indicate the end of the unmasked parts, and at the end of the masked parts, a special token is added
        to indicate the end of the masked parts.
        """
        # sample num_patches from a discrete poisson distribution with upper limit L
        def sample_lengths(start, end):
            """
            Sample a length uniformly from 1 to max_L*self.mask_fraction (must be bigger than 1).
            If the length is larger than max_L, return max_L.
            """
            max_L = end - start
            length = np.random.randint(1, max(int(max_L * self.mask_fraction), 2))
            return min(length, max_L)

        # sample num_patches from a discrete poisson distribution with upper limit max_patches
        num_patches = 1000
        while num_patches > self.max_patches:
            num_patches = np.random.poisson(1)
        if self.always_mask:
            num_patches = max(num_patches, 1)
        # sample num_patches starting points for the masked positions (+ final position)
        start_patches = list(np.sort(np.random.choice(np.arange(t[0] + 1, t[1]),
                                                       num_patches,
                                                       replace=False))) + [t[1]]
        # sample num_patches lengths of the patches
        len_patches = [sample_lengths(start_patches[i], start_patches[i + 1])
                       for i in range(len(start_patches) - 1)]
        # create masked tuples with start and end indices of the patches
        masked_tuples = [(start_patches[i], start_patches[i] + len_patches[i]) for i in range(len(start_patches) - 1)]
        # split the sequences into unmasked and masked parts
        unmasked_sequence, masked_sequence, unmasked_position_ids, masked_position_ids = self.split_sequences(sequences,
                                                                                                              t,
                                                                                                              masked_tuples)

        if self.troubleshoot:
            print(f"For sequence in {t}: sampled {num_patches=}, {start_patches=}, {len_patches=}, {masked_tuples=}")
        # concatenate the unmasked and masked parts
        return unmasked_sequence + masked_sequence, unmasked_position_ids + masked_position_ids if self.add_position_ids else None

    def split_sequences(self, sequences, t, masked_tuples):
        """
        This function splits the sequences into unmasked and masked parts based on the given tuples.
        Args:
            t (tuple): The start and end index of each sequence.
            masked_tuples (list): A list of tuples specifying the indices for masked regions.
        Returns:
            unmasked_parts (list): The unmasked parts of the sequences interleaved with mask_tokens.
            masked_parts (list): The masked parts of the sequences interleaved with mask_tokens.
        """
        unmasked_parts, masked_parts = [], []
        unmasked_positions, masked_positions = [], []
        position_ids = None
        start, end = t
        if self.add_position_ids:
            position_ids = np.arange(start, end) - start
        for i, region in enumerate(masked_tuples):
            mask_token = self.mask_tokens[f"<mask-{i + 1}>"]
            unmasked_parts.extend(sequences[start:region[0]])
            unmasked_parts.append(mask_token)
            masked_parts.append(mask_token)
            masked_parts.extend(sequences[region[0]:region[1]])
            if self.add_position_ids:
                unmasked_positions.extend(position_ids[start - t[0]:region[0] - t[0]])
                unmasked_positions.append(position_ids[region[0] - t[0]])
                masked_positions.append(position_ids[region[0] - t[0]])
                masked_positions.extend(position_ids[region[0] - t[0]:region[1] - t[0]])

            start = region[1]
        unmasked_parts.extend(sequences[start:end])
        if self.add_position_ids:
            unmasked_positions.extend(position_ids[start - t[0]:end - t[0]])
        if len(masked_tuples) > 0:
            unmasked_parts.append(self.eos_token)
            if self.add_position_ids:
                unmasked_positions.append(0)
        return unmasked_parts, masked_parts, unmasked_positions, masked_positions
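A toy illustration of the multiple-span FIM transform on a single "sequence" of arbitrary token ids; only the masking pattern is meaningful here, and it assumes `MASK_TO_ID` provides at least three mask tokens as required by the assertion above.

```python
# Toy illustration: apply MultipleSpanFIM to one span covering a fake sequence.
import numpy as np
from protxlstm.fim import MultipleSpanFIM

np.random.seed(0)
sequences = np.arange(40)   # stand-in token ids for one sequence of length 40
tuples = [(0, 40)]          # one (start, end) span covering the whole sequence

fim = MultipleSpanFIM(max_patches=3, mask_fraction=0.2, always_mask=True, add_position_ids=True)
input_ids, position_ids = fim.apply(sequences, tuples)
print(len(input_ids), len(position_ids))   # scrambled tokens and matching position ids
```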
protxlstm/generation.py
ADDED
@@ -0,0 +1,384 @@
1 |
+
# Original code from ProtMamba under Apache License 2.0.
|
2 |
+
#
|
3 |
+
# Modifications made by Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
|
4 |
+
# - Add option to pass input state for generation
|
5 |
+
# - Add functions to generate sequences with xlstm
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from protxlstm.mamba_utils_generation import (
|
10 |
+
InferenceParams,
|
11 |
+
GenerationMixin,
|
12 |
+
GreedySearchDecoderOnlyOutput,
|
13 |
+
modify_logits_for_top_p_filtering,
|
14 |
+
modify_logits_for_min_p_filtering,
|
15 |
+
modify_logit_for_repetition_penalty,
|
16 |
+
SampleDecoderOnlyOutput,
|
17 |
+
update_graph_cache
|
18 |
+
)
|
19 |
+
|
20 |
+
from protxlstm.utils import AA_TO_ID, decode_sequence
|
21 |
+
|
22 |
+
def sample_safe(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
|
23 |
+
"""Sample from top-k logits.
|
24 |
+
Arguments:
|
25 |
+
logits: Tensor of shape (batch_size, vocab_size)
|
26 |
+
"""
|
27 |
+
if top_k == 1: # Short-circuit for greedy decoding
|
28 |
+
return logits.argmax(dim=-1)
|
29 |
+
else:
|
30 |
+
if top_p > 0.0:
|
31 |
+
assert top_p <= 1.0, "top-p should be in (0, 1]."
|
32 |
+
if top_k > 0:
|
33 |
+
top_k = min(top_k, logits.size(-1)) # Safety check
|
34 |
+
logits_top, indices = torch.topk(logits, top_k, dim=-1)
|
35 |
+
if temperature != 1.0:
|
36 |
+
logits_top /= temperature
|
37 |
+
modify_logits_for_top_p_filtering(logits_top, top_p)
|
38 |
+
|
39 |
+
return indices[
|
40 |
+
torch.arange(indices.shape[0], device=indices.device),
|
41 |
+
torch.multinomial(
|
42 |
+
torch.softmax(logits_top, dim=-1), num_samples=1
|
43 |
+
).squeeze(dim=-1),
|
44 |
+
]
|
45 |
+
else:
|
46 |
+
if min_p > 0.0:
|
47 |
+
logits_top = logits.clone()
|
48 |
+
max_prob = logits_top[..., 0].item()
|
49 |
+
min_prob = max_prob * min_p
|
50 |
+
modify_logits_for_min_p_filtering(logits_top, min_p)
|
51 |
+
if temperature != 1.0:
|
52 |
+
logits_top /= temperature
|
53 |
+
return torch.multinomial(
|
54 |
+
torch.softmax(logits_top, dim=-1), num_samples=1
|
55 |
+
).squeeze(dim=-1)
|
56 |
+
# Clone so that when we modify for top_p we don't change the original logits
|
57 |
+
logits_top = logits / temperature if temperature != 1.0 else logits.clone()
|
58 |
+
modify_logits_for_top_p_filtering(logits_top, top_p)
|
59 |
+
return torch.multinomial(
|
60 |
+
torch.softmax(logits_top, dim=-1), num_samples=1
|
61 |
+
).squeeze(dim=-1)
|
62 |
+
|
63 |
+
|
64 |
+
@torch.inference_mode()
|
65 |
+
def decode_safe(
|
66 |
+
input_ids,
|
67 |
+
position_ids,
|
68 |
+
seq_position_ids,
|
69 |
+
is_fim,
|
70 |
+
model,
|
71 |
+
max_length,
|
72 |
+
state=None,
|
73 |
+
top_k=1,
|
74 |
+
top_p=0.0,
|
75 |
+
min_p=0.0,
|
76 |
+
temperature=1.0,
|
77 |
+
repetition_penalty=1.0,
|
78 |
+
eos_token_id=None,
|
79 |
+
teacher_outputs=None,
|
80 |
+
vocab_size=None,
|
81 |
+
cg=False,
|
82 |
+
enable_timing=False,
|
83 |
+
streamer = None,
|
84 |
+
chunk_chunk_size = 2**15,
|
85 |
+
):
|
86 |
+
"""Decoding, either greedy or with top-k or top-p sampling.
|
87 |
+
If top-k = 0, don't limit the number of candidates (pure sampling).
|
88 |
+
Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
|
89 |
+
then top-p.
|
90 |
+
We assume that all sequences in the same batch have the same length.
|
91 |
+
|
92 |
+
Arguments:
|
93 |
+
input_ids: (batch, seq_len)
|
94 |
+
max_length: int
|
95 |
+
is_fim: dictionary with mask indices and associated position indices
|
96 |
+
teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
|
97 |
+
logits, the next token is taken from the teacher_outputs. Useful for testing.
|
98 |
+
Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
|
99 |
+
sequences: (batch, max_length)
|
100 |
+
scores: tuples of (batch, vocab_size)
|
101 |
+
"""
|
102 |
+
if streamer is not None:
|
103 |
+
streamer.put(input_ids.cpu())
|
104 |
+
|
105 |
+
batch_size, seqlen_og = input_ids.shape
|
106 |
+
teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
|
107 |
+
if cg:
|
108 |
+
if not hasattr(model, "_decoding_cache"):
|
109 |
+
model._decoding_cache = None
|
110 |
+
model._decoding_cache = update_graph_cache(
|
111 |
+
model,
|
112 |
+
model._decoding_cache,
|
113 |
+
batch_size,
|
114 |
+
seqlen_og,
|
115 |
+
max_length,
|
116 |
+
)
|
117 |
+
inference_params = model._decoding_cache.inference_params
|
118 |
+
inference_params.reset(max_length, batch_size)
|
119 |
+
else:
|
120 |
+
inference_params = InferenceParams(
|
121 |
+
max_seqlen=max_length, max_batch_size=batch_size
|
122 |
+
)
|
123 |
+
|
124 |
+
def get_logits(input_ids, position_ids, seq_position_ids, inference_params):
|
125 |
+
decoding = inference_params.seqlen_offset > 0
|
126 |
+
if not cg or not decoding:
|
127 |
+
logits = model(
|
128 |
+
input_ids,
|
129 |
+
position_ids=position_ids,
|
130 |
+
seq_position_ids=seq_position_ids,
|
131 |
+
inference_params=inference_params,
|
132 |
+
num_last_tokens=1,
|
133 |
+
).logits.squeeze(dim=1)
|
134 |
+
else:
|
135 |
+
logits = model._decoding_cache.run(
|
136 |
+
input_ids,
|
137 |
+
position_ids,
|
138 |
+
inference_params.seqlen_offset,
|
139 |
+
seq_position_ids=seq_position_ids,
|
140 |
+
).squeeze(dim=1)
|
141 |
+
return logits[..., :vocab_size] if vocab_size is not None else logits
|
142 |
+
|
143 |
+
def get_xlstm_logits_step(input_ids, position_ids, seq_position_ids, state):
|
144 |
+
|
145 |
+
if not input_ids.shape[1] == 1:
|
146 |
+
|
147 |
+
for i in range(input_ids.shape[1]):
|
148 |
+
if position_ids != None:
|
149 |
+
token_position_ids = position_ids[:,i:(i+1)]
|
150 |
+
else:
|
151 |
+
token_position_ids = None
|
152 |
+
if seq_position_ids != None:
|
153 |
+
token_seq_position_ids = seq_position_ids[:,i:(i+1)]
|
154 |
+
else:
|
155 |
+
token_seq_position_ids = None
|
156 |
+
logits, state = model.step(input_ids[:,i:(i+1)], state, position_ids=token_position_ids, seq_position_ids=token_seq_position_ids)
|
157 |
+
|
158 |
+
else:
|
159 |
+
logits, state = model.step(input_ids, state, position_ids=position_ids, seq_position_ids=seq_position_ids)
|
160 |
+
|
161 |
+
logits = logits.squeeze(dim=1)
|
162 |
+
if vocab_size is not None:
|
163 |
+
logits = logits[..., :vocab_size]
|
164 |
+
|
165 |
+
return logits, state
|
166 |
+
|
167 |
+
def get_xlstm_logits_chunkwise(input_ids, position_ids, seq_position_ids, chunk_chunk_size=2**15, state=None):
|
168 |
+
|
169 |
+
assert model.config.config_dataclass.mlstm_block.mlstm.backend == "chunkwise_variable"
|
170 |
+
|
171 |
+
for chunk in range(input_ids.shape[1]//chunk_chunk_size+1):
|
172 |
+
|
173 |
+
start_idx = chunk*chunk_chunk_size
|
174 |
+
end_idx = min((chunk+1)*chunk_chunk_size, input_ids.shape[1])
|
175 |
+
|
176 |
+
if start_idx == end_idx:
|
177 |
+
pass
|
178 |
+
|
179 |
+
else:
|
180 |
+
input_ids_chunk = input_ids[:, start_idx:end_idx]
|
181 |
+
|
182 |
+
if not position_ids == None:
|
183 |
+
position_ids_chunk = position_ids[:, start_idx:end_idx]
|
184 |
+
else:
|
185 |
+
position_ids_chunk = None
|
186 |
+
|
187 |
+
if not seq_position_ids == None:
|
188 |
+
seq_position_ids_chunk = seq_position_ids[:, start_idx:end_idx]
|
189 |
+
else:
|
190 |
+
seq_position_ids_chunk = None
|
191 |
+
|
192 |
+
outputs = model(input_ids_chunk, position_ids=position_ids_chunk, seq_position_ids=seq_position_ids_chunk, state=state)
|
193 |
+
logits, state = outputs.logits, outputs.state
|
194 |
+
|
195 |
+
logits = logits[:,-1,:]
|
196 |
+
logits = logits.squeeze(dim=1)
|
197 |
+
if vocab_size is not None:
|
198 |
+
logits = logits[..., :vocab_size]
|
199 |
+
|
200 |
+
return logits, state
|
201 |
+
|
202 |
+
def sample_tokens(logits, inference_params):
|
203 |
+
if (
|
204 |
+
teacher_outputs is None
|
205 |
+
or teacher_output_len <= inference_params.seqlen_offset
|
206 |
+
):
|
207 |
+
token = sample_safe(
|
208 |
+
logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature
|
209 |
+
)
|
210 |
+
else:
|
211 |
+
token = teacher_outputs[:, inference_params.seqlen_offset]
|
212 |
+
# return rearrange(token, "b -> b 1")
|
213 |
+
return token.unsqueeze(1)
|
214 |
+
|
215 |
+
def get_fim_position_id(
|
216 |
+
last_position_ids, sampled_tokens, is_fim, repeat_next=False
|
217 |
+
):
|
218 |
+
if type(is_fim) is dict:
|
219 |
+
val = int(last_position_ids) + 1
|
220 |
+
should_repeat_next = False
|
221 |
+
if is_fim and int(sampled_tokens) in is_fim:
|
222 |
+
val = is_fim[int(sampled_tokens)]
|
223 |
+
should_repeat_next = True
|
224 |
+
elif repeat_next:
|
225 |
+
val = int(last_position_ids)
|
226 |
+
return torch.full_like(last_position_ids, fill_value=val), should_repeat_next
|
227 |
+
else:
|
228 |
+
t = [get_fim_position_id(last_position_ids_, sampled_tokens_, is_fim_dict, repeat_next) for
|
229 |
+
(last_position_ids_, sampled_tokens_, is_fim_dict) in
|
230 |
+
zip(last_position_ids, sampled_tokens, is_fim)]
|
231 |
+
return torch.stack([t_[0] for t_ in t], dim=0), t[0][1]
|
232 |
+
|
233 |
+
def should_stop(current_token, inference_params):
|
234 |
+
if inference_params.seqlen_offset == 0:
|
235 |
+
return False
|
236 |
+
if eos_token_id is not None and (current_token == eos_token_id).any():
|
237 |
+
if current_token.shape[1] > 1:
|
238 |
+
raise NotImplementedError("Batched eos_token_id not supported")
|
239 |
+
return True
|
240 |
+
if inference_params.seqlen_offset >= max_length - 1:
|
241 |
+
return True
|
242 |
+
return False
|
243 |
+
|
244 |
+
start = torch.cuda.Event(enable_timing=enable_timing)
|
245 |
+
end = torch.cuda.Event(enable_timing=enable_timing)
|
246 |
+
|
247 |
+
if enable_timing:
|
248 |
+
start.record()
|
249 |
+
scores, sequences = [], [input_ids]
|
250 |
+
new_position_ids, new_seq_position_ids = [position_ids], [seq_position_ids]
|
251 |
+
sequences_cat = input_ids
|
252 |
+
repeat_next = False
|
253 |
+
if position_ids.shape[0] > 1:
|
254 |
+
raise NotImplementedError("Batched generation with position_ids not supported")
|
255 |
+
|
256 |
+
encode_context=True
|
257 |
+
while not should_stop(sequences[-1], inference_params):
|
258 |
+
|
259 |
+
from protxlstm.models.xlstm import xLSTMLMHeadModel
|
260 |
+
if isinstance(model, xLSTMLMHeadModel):
|
261 |
+
if encode_context:
|
262 |
+
with torch.no_grad():
|
263 |
+
logits, state = get_xlstm_logits_chunkwise(sequences[-1], new_position_ids[-1], new_seq_position_ids[-1], state=state, chunk_chunk_size=chunk_chunk_size)
|
264 |
+
encode_context = False
|
265 |
+
else:
|
266 |
+
logits, state = get_xlstm_logits_step(sequences[-1], new_position_ids[-1], new_seq_position_ids[-1], state=state)
|
267 |
+
else:
|
268 |
+
logits = get_logits(sequences[-1], new_position_ids[-1], new_seq_position_ids[-1], inference_params)
|
269 |
+
|
270 |
+
scores.append(logits)
|
271 |
+
|
272 |
+
inference_params.seqlen_offset += sequences[-1].shape[1]
|
273 |
+
if repetition_penalty == 1.0:
|
274 |
+
sampled_tokens = sample_tokens(scores[-1], inference_params)
|
275 |
+
else:
|
276 |
+
logits = modify_logit_for_repetition_penalty(
|
277 |
+
scores[-1].clone(), sequences_cat, repetition_penalty
|
278 |
+
)
|
279 |
+
sampled_tokens = sample_tokens(logits, inference_params)
|
280 |
+
sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
|
281 |
+
sequences.append(sampled_tokens)
|
282 |
+
# Update position_ids
|
283 |
+
if position_ids is not None:
|
284 |
+
last_position_ids, repeat_next = get_fim_position_id(
|
285 |
+
new_position_ids[-1][:, -1:], sampled_tokens, is_fim, repeat_next
|
286 |
+
)
|
287 |
+
new_position_ids.append(last_position_ids)
|
288 |
+
# Update seq_position_ids
|
289 |
+
if seq_position_ids is not None:
|
290 |
+
new_seq_position_ids.append(new_seq_position_ids[-1][:, -1:])
|
291 |
+
|
292 |
+
if streamer is not None:
|
293 |
+
streamer.put(sampled_tokens.cpu())
|
294 |
+
if streamer is not None:
|
295 |
+
streamer.end()
|
296 |
+
if enable_timing:
|
297 |
+
end.record()
|
298 |
+
torch.cuda.synchronize()
|
299 |
+
print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
|
300 |
+
output_cls = (
|
301 |
+
GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
|
302 |
+
)
|
303 |
+
return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
|
304 |
+
|
305 |
+
|
306 |
+
class GenerationMixinSafe(GenerationMixin):
|
307 |
+
|
308 |
+
def generate(
|
309 |
+
self,
|
310 |
+
input_ids,
|
311 |
+
position_ids,
|
312 |
+
seq_position_ids,
|
313 |
+
is_fim=None,
|
314 |
+
state=None,
|
315 |
+
max_length=1,
|
316 |
+
top_k=1,
|
317 |
+
top_p=0.0,
|
318 |
+
min_p=0.0,
|
319 |
+
temperature=1.0,
|
320 |
+
return_dict_in_generate=False,
|
321 |
+
output_scores=False,
|
322 |
+
chunk_chunk_size=2**15,
|
323 |
+
**kwargs,
|
324 |
+
):
|
325 |
+
|
326 |
+
output = decode_safe(
|
327 |
+
input_ids,
|
328 |
+
position_ids,
|
329 |
+
seq_position_ids,
|
330 |
+
is_fim,
|
331 |
+
self,
|
332 |
+
max_length,
|
333 |
+
state=state,
|
334 |
+
top_k=top_k,
|
335 |
+
top_p=top_p,
|
336 |
+
min_p=min_p,
|
337 |
+
temperature=temperature,
|
338 |
+
chunk_chunk_size=chunk_chunk_size,
|
339 |
+
**kwargs,
|
340 |
+
)
|
341 |
+
if not output_scores:
|
342 |
+
output.scores = None
|
343 |
+
return output if return_dict_in_generate else output.sequences
|
344 |
+
|
345 |
+
|
346 |
+
def generate_sequence(model, tokens, position_ids=None, seq_position_ids=None, state=None, is_fim=False, max_length=2000, temperature=1., top_p=0.0, top_k=1,
|
347 |
+
return_dict_in_generate=False, output_scores=False, eos_token_id=AA_TO_ID["<cls>"], device="cuda", chunk_chunk_size=2**15):
|
348 |
+
"""Generating, either greedy or with top-k or top-p sampling.
|
349 |
+
If top-k = 0, don't limit the number of candidates (pure sampling).
|
350 |
+
Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
|
351 |
+
then top-p. We assume that all sequences in the same batch have the same length.
|
352 |
+
"""
|
353 |
+
input_ids = tokens.to(device)
|
354 |
+
position_ids = position_ids.to(device) if position_ids is not None else None
|
355 |
+
seq_position_ids = seq_position_ids.to(device) if seq_position_ids is not None else None
|
356 |
+
# generate sequence
|
357 |
+
out = model.generate(input_ids=input_ids,
|
358 |
+
position_ids=position_ids,
|
359 |
+
seq_position_ids=seq_position_ids,
|
360 |
+
is_fim=is_fim,
|
361 |
+
state=state,
|
362 |
+
max_length=max_length,
|
363 |
+
temperature=temperature,
|
364 |
+
top_p=top_p,
|
365 |
+
top_k=top_k,
|
366 |
+
return_dict_in_generate=return_dict_in_generate,
|
367 |
+
output_scores=output_scores,
|
368 |
+
eos_token_id=eos_token_id,
|
369 |
+
chunk_chunk_size=chunk_chunk_size,
|
370 |
+
)
|
371 |
+
sequences = out.sequences
|
372 |
+
dic = {"input": [decode_sequence(seq) for seq in sequences[:, :input_ids.shape[-1]].cpu().numpy()],
|
373 |
+
"generated": [decode_sequence(seq) for seq in sequences[:, input_ids.shape[-1]:].cpu().numpy()],
|
374 |
+
"input_tokens": [seq for seq in sequences[:, :input_ids.shape[-1]].cpu().numpy()],
|
375 |
+
"generated_tokens": [seq for seq in sequences[:, input_ids.shape[-1]:].cpu().numpy()]}
|
376 |
+
if output_scores:
|
377 |
+
dic["scores"] = np.array([el.to(torch.float32).cpu().numpy() for el in out.scores]).transpose(1, 0, 2)
|
378 |
+
return dic
|
379 |
+
|
380 |
+
|
381 |
+
|
382 |
+
|
383 |
+
|
384 |
+
|
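A minimal usage sketch for `generate_sequence` (illustration only, not part of the diff): it assumes `model` is an already-loaded Prot-xLSTM model on GPU, and the placeholder tensors below stand in for token ids and position ids produced by the repo's own data utilities.

```python
# Sketch: sampling a continuation with generate_sequence (assumes `model` is a loaded
# xLSTMLMHeadModel on CUDA; the context tensors below are placeholders for tokens and
# position ids built with the repo's own tokenization utilities).
import torch

context_tokens = torch.randint(4, 20, (1, 128), dtype=torch.long)    # placeholder token ids
context_pos_ids = torch.arange(128, dtype=torch.long).unsqueeze(0)   # placeholder position ids

out = generate_sequence(
    model,
    context_tokens,
    position_ids=context_pos_ids,
    max_length=256,                  # total length, context included
    temperature=1.0,
    top_k=10,
    top_p=0.9,
    return_dict_in_generate=True,    # generate_sequence reads out.sequences, so this must be True
    output_scores=False,
    device="cuda",
)
print(out["generated"][0])           # decoded string of the newly sampled tokens
```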
protxlstm/index.html
ADDED
@@ -0,0 +1,16 @@
1 |
+
<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 3.2 Final//EN">
|
2 |
+
<html>
|
3 |
+
<head>
|
4 |
+
<title>Index of /research/Bio-xLSTM/downloads/Prot-xLSTM/checkpoints/protxlstm_26M_30B</title>
|
5 |
+
</head>
|
6 |
+
<body>
|
7 |
+
<h1>Index of /research/Bio-xLSTM/downloads/Prot-xLSTM/checkpoints/protxlstm_26M_30B</h1>
|
8 |
+
<pre><img src="/icons/blank.gif" alt="Icon "> <a href="?C=N;O=D">Name</a> <a href="?C=M;O=A">Last modified</a> <a href="?C=S;O=A">Size</a> <a href="?C=D;O=A">Description</a><hr><img src="/icons/back.gif" alt="[PARENTDIR]"> <a href="/research/Bio-xLSTM/downloads/Prot-xLSTM/checkpoints/">Parent Directory</a> -
|
9 |
+
<img src="/icons/unknown.gif" alt="[ ]"> <a href="config.json">config.json</a> 2024-11-04 14:36 1.8K
|
10 |
+
<img src="/icons/unknown.gif" alt="[ ]"> <a href="optimizer.pt">optimizer.pt</a> 2024-11-04 14:36 198M
|
11 |
+
<img src="/icons/binary.gif" alt="[ ]"> <a href="pytorch_model.bin">pytorch_model.bin</a> 2024-11-04 14:36 99M
|
12 |
+
<img src="/icons/unknown.gif" alt="[ ]"> <a href="rng_state.pth">rng_state.pth</a> 2024-11-04 14:36 14K
|
13 |
+
<img src="/icons/unknown.gif" alt="[ ]"> <a href="scheduler.pt">scheduler.pt</a> 2024-11-04 14:36 1.0K
|
14 |
+
<img src="/icons/unknown.gif" alt="[ ]"> <a href="trainer_state.json">trainer_state.json</a> 2024-11-04 14:36 2.4M
|
15 |
+
<hr></pre>
|
16 |
+
</body></html>
|
protxlstm/mamba_utils_generation.py
ADDED
@@ -0,0 +1,382 @@
1 |
+
# From: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/utils/generation.py
|
2 |
+
# Copyright (c) 2023, Albert Gu, Tri Dao.
|
3 |
+
import gc
|
4 |
+
from dataclasses import dataclass, field
|
5 |
+
from typing import Callable, Optional
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch import Tensor
|
9 |
+
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer
|
10 |
+
|
11 |
+
|
12 |
+
@dataclass
|
13 |
+
class InferenceParams:
|
14 |
+
"""Inference parameters that are passed to the main model in order
|
15 |
+
to efficiently calculate and store the context during inference."""
|
16 |
+
|
17 |
+
max_seqlen: int
|
18 |
+
max_batch_size: int
|
19 |
+
seqlen_offset: int = 0
|
20 |
+
batch_size_offset: int = 0
|
21 |
+
key_value_memory_dict: dict = field(default_factory=dict)
|
22 |
+
lengths_per_sample: Optional[Tensor] = None
|
23 |
+
|
24 |
+
def reset(self, max_seqlen, max_batch_size):
|
25 |
+
self.max_seqlen = max_seqlen
|
26 |
+
self.max_batch_size = max_batch_size
|
27 |
+
self.seqlen_offset = 0
|
28 |
+
if self.lengths_per_sample is not None:
|
29 |
+
self.lengths_per_sample.zero_()
|
30 |
+
|
31 |
+
|
32 |
+
def modify_logits_for_min_p_filtering(logits, min_p):
|
33 |
+
"""Set the logits for none min_p values to -inf. Done in-place."""
|
34 |
+
if min_p <= 0.0 or min_p >= 1.0:
|
35 |
+
return
|
36 |
+
indices_to_remove = logits < min_p
|
37 |
+
logits.masked_fill_(indices_to_remove, float("-Inf"))
|
38 |
+
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
|
39 |
+
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
|
40 |
+
def modify_logits_for_top_k_filtering(logits, top_k):
|
41 |
+
"""Set the logits for none top-k values to -inf. Done in-place."""
|
42 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
43 |
+
logits.masked_fill_(indices_to_remove, float("-Inf"))
|
44 |
+
|
45 |
+
|
46 |
+
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
|
47 |
+
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
|
48 |
+
def modify_logits_for_top_p_filtering(logits, top_p):
|
49 |
+
"""Set the logits for none top-p values to -inf. Done in-place."""
|
50 |
+
if top_p <= 0.0 or top_p >= 1.0:
|
51 |
+
return
|
52 |
+
# First sort and calculate cumulative sum of probabilities.
|
53 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=False)
|
54 |
+
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
|
55 |
+
# Remove tokens in the low-probability tail whose cumulative probability is at most (1 - top_p)
|
56 |
+
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
|
57 |
+
# scatter sorted tensors to original indexing
|
58 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
59 |
+
1, sorted_indices, sorted_indices_to_remove
|
60 |
+
)
|
61 |
+
logits.masked_fill_(indices_to_remove, float("-inf"))
|
62 |
+
|
63 |
+
|
64 |
+
def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_penalty=1.0):
|
65 |
+
"""Apply repetition penalty. See https://arxiv.org/abs/1909.05858
|
66 |
+
logits: (batch_size, vocab_size)
|
67 |
+
prev_output_tokens: (batch_size, seq_len)
|
68 |
+
"""
|
69 |
+
if repetition_penalty == 1.0:
|
70 |
+
return logits
|
71 |
+
score = torch.gather(logits, 1, prev_output_tokens)
|
72 |
+
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
|
73 |
+
score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
|
74 |
+
logits.scatter_(1, prev_output_tokens, score)
|
75 |
+
return logits
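For illustration (not part of the upstream file), a tiny sketch of the penalty's effect on uniform logits:

```python
# Sketch: tokens 3 and 7 were already generated, so their logits are divided by the penalty.
import torch

logits = torch.ones(1, 10)                       # (batch_size, vocab_size)
prev = torch.tensor([[3, 7]])                    # (batch_size, seq_len)
out = modify_logit_for_repetition_penalty(logits.clone(), prev, repetition_penalty=1.2)
print(out[0, 3].item(), out[0, 0].item())        # ~0.83 for a repeated token vs 1.0 elsewhere
```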
|
76 |
+
|
77 |
+
|
78 |
+
def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
|
79 |
+
"""Sample from top-k logits.
|
80 |
+
Arguments:
|
81 |
+
logits: Tensor of shape (batch_size, vocab_size)
|
82 |
+
"""
|
83 |
+
if top_k == 1: # Short-circuit for greedy decoding
|
84 |
+
return logits.argmax(dim=-1)
|
85 |
+
else:
|
86 |
+
if top_p > 0.0:
|
87 |
+
assert top_p <= 1.0, "top-p should be in (0, 1]."
|
88 |
+
if top_k > 0:
|
89 |
+
top_k = min(top_k, logits.size(-1)) # Safety check
|
90 |
+
logits_top, indices = torch.topk(logits, top_k, dim=-1)
|
91 |
+
if temperature != 1.0:
|
92 |
+
logits_top /= temperature
|
93 |
+
modify_logits_for_top_p_filtering(logits_top, top_p)
|
94 |
+
return indices[
|
95 |
+
torch.arange(indices.shape[0], device=indices.device),
|
96 |
+
torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
|
97 |
+
]
|
98 |
+
else:
|
99 |
+
if min_p > 0.0:
|
100 |
+
logits_top = logits.clone()
|
101 |
+
max_prob = logits_top[..., 0].item()
|
102 |
+
min_prob = max_prob * min_p
|
103 |
+
modify_logits_for_min_p_filtering(logits_top, min_prob)
|
104 |
+
if temperature != 1.0:
|
105 |
+
logits_top /= temperature
|
106 |
+
return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
|
107 |
+
# Clone so that when we modify for top_p we don't change the original logits
|
108 |
+
logits_top = logits / temperature if temperature != 1.0 else logits.clone()
|
109 |
+
modify_logits_for_top_p_filtering(logits_top, top_p)
|
110 |
+
return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
|
111 |
+
dim=-1
|
112 |
+
)
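A short sketch of how this sampling entry point composes the filters above (illustration only; the import path is an assumption):

```python
# Sketch: greedy vs. stochastic decoding with combined top-k / top-p filtering.
# Assumes this file is importable as protxlstm.mamba_utils_generation.
import torch
from protxlstm.mamba_utils_generation import sample

logits = torch.randn(4, 40)                      # (batch_size, vocab_size)
greedy = sample(logits, top_k=1)                 # plain argmax
stochastic = sample(logits, top_k=10, top_p=0.9, temperature=0.8)  # top-k first, then top-p
print(greedy.shape, stochastic.shape)            # both torch.Size([4])
```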
|
113 |
+
|
114 |
+
|
115 |
+
@torch.inference_mode()
|
116 |
+
def decode(
|
117 |
+
input_ids,
|
118 |
+
model,
|
119 |
+
max_length,
|
120 |
+
top_k=1,
|
121 |
+
top_p=0.0,
|
122 |
+
min_p=0.0,
|
123 |
+
temperature=1.0,
|
124 |
+
repetition_penalty=1.0,
|
125 |
+
eos_token_id=None,
|
126 |
+
teacher_outputs=None,
|
127 |
+
vocab_size=None,
|
128 |
+
cg=False,
|
129 |
+
enable_timing=False,
|
130 |
+
streamer: Optional[TextStreamer] = None
|
131 |
+
):
|
132 |
+
"""Decoding, either greedy or with top-k or top-p sampling.
|
133 |
+
If top-k = 0, don't limit the number of candidates (pure sampling).
|
134 |
+
Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
|
135 |
+
then top-p.
|
136 |
+
We assume that all sequences in the same batch have the same length.
|
137 |
+
|
138 |
+
Arguments:
|
139 |
+
input_ids: (batch, seq_len)
|
140 |
+
max_length: int
|
141 |
+
teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
|
142 |
+
logits, the next token is taken from the teacher_outputs. Useful for testing.
|
143 |
+
Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
|
144 |
+
sequences: (batch, max_length)
|
145 |
+
scores: tuples of (batch, vocab_size)
|
146 |
+
"""
|
147 |
+
if streamer is not None:
|
148 |
+
streamer.put(input_ids.cpu())
|
149 |
+
|
150 |
+
batch_size, seqlen_og = input_ids.shape
|
151 |
+
teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
|
152 |
+
if cg:
|
153 |
+
if not hasattr(model, "_decoding_cache"):
|
154 |
+
model._decoding_cache = None
|
155 |
+
model._decoding_cache = update_graph_cache(
|
156 |
+
model,
|
157 |
+
model._decoding_cache,
|
158 |
+
batch_size,
|
159 |
+
seqlen_og,
|
160 |
+
max_length,
|
161 |
+
)
|
162 |
+
inference_params = model._decoding_cache.inference_params
|
163 |
+
inference_params.reset(max_length, batch_size)
|
164 |
+
else:
|
165 |
+
inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
|
166 |
+
|
167 |
+
def get_logits(input_ids, inference_params):
|
168 |
+
decoding = inference_params.seqlen_offset > 0
|
169 |
+
if decoding:
|
170 |
+
position_ids = torch.full(
|
171 |
+
(batch_size, 1),
|
172 |
+
inference_params.seqlen_offset,
|
173 |
+
dtype=torch.long,
|
174 |
+
device=input_ids.device,
|
175 |
+
)
|
176 |
+
else:
|
177 |
+
position_ids = None
|
178 |
+
if not cg or not decoding:
|
179 |
+
logits = model(
|
180 |
+
input_ids,
|
181 |
+
position_ids=position_ids,
|
182 |
+
inference_params=inference_params,
|
183 |
+
num_last_tokens=1,
|
184 |
+
).logits.squeeze(dim=1)
|
185 |
+
else:
|
186 |
+
logits = model._decoding_cache.run(
|
187 |
+
input_ids, position_ids, inference_params.seqlen_offset
|
188 |
+
).squeeze(dim=1)
|
189 |
+
return logits[..., :vocab_size] if vocab_size is not None else logits
|
190 |
+
|
191 |
+
def sample_tokens(logits, inference_params):
|
192 |
+
if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
|
193 |
+
token = sample(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature)
|
194 |
+
else:
|
195 |
+
token = teacher_outputs[:, inference_params.seqlen_offset]
|
196 |
+
# return rearrange(token, "b -> b 1")
|
197 |
+
return token.unsqueeze(1)
|
198 |
+
|
199 |
+
def should_stop(current_token, inference_params):
|
200 |
+
if inference_params.seqlen_offset == 0:
|
201 |
+
return False
|
202 |
+
if eos_token_id is not None and (current_token == eos_token_id).all():
|
203 |
+
return True
|
204 |
+
if inference_params.seqlen_offset >= max_length - 1:
|
205 |
+
return True
|
206 |
+
return False
|
207 |
+
|
208 |
+
start = torch.cuda.Event(enable_timing=enable_timing)
|
209 |
+
end = torch.cuda.Event(enable_timing=enable_timing)
|
210 |
+
|
211 |
+
if enable_timing:
|
212 |
+
start.record()
|
213 |
+
scores, sequences = [], [input_ids]
|
214 |
+
sequences_cat = input_ids
|
215 |
+
while not should_stop(sequences[-1], inference_params):
|
216 |
+
scores.append(get_logits(sequences[-1], inference_params))
|
217 |
+
inference_params.seqlen_offset += sequences[-1].shape[1]
|
218 |
+
if repetition_penalty == 1.0:
|
219 |
+
sampled_tokens = sample_tokens(scores[-1], inference_params)
|
220 |
+
else:
|
221 |
+
logits = modify_logit_for_repetition_penalty(
|
222 |
+
scores[-1].clone(), sequences_cat, repetition_penalty
|
223 |
+
)
|
224 |
+
sampled_tokens = sample_tokens(logits, inference_params)
|
225 |
+
sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
|
226 |
+
sequences.append(sampled_tokens)
|
227 |
+
if streamer is not None:
|
228 |
+
streamer.put(sampled_tokens.cpu())
|
229 |
+
if streamer is not None:
|
230 |
+
streamer.end()
|
231 |
+
if enable_timing:
|
232 |
+
end.record()
|
233 |
+
torch.cuda.synchronize()
|
234 |
+
print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
|
235 |
+
output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
|
236 |
+
return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
|
237 |
+
|
238 |
+
|
239 |
+
class GenerationMixin:
|
240 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
241 |
+
raise NotImplementedError
|
242 |
+
|
243 |
+
def generate(
|
244 |
+
self,
|
245 |
+
input_ids,
|
246 |
+
max_length,
|
247 |
+
top_k=1,
|
248 |
+
top_p=0.0,
|
249 |
+
min_p=0.0,
|
250 |
+
temperature=1.0,
|
251 |
+
return_dict_in_generate=False,
|
252 |
+
output_scores=False,
|
253 |
+
**kwargs,
|
254 |
+
):
|
255 |
+
output = decode(
|
256 |
+
input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p = min_p, temperature=temperature, **kwargs
|
257 |
+
)
|
258 |
+
if not output_scores:
|
259 |
+
output.scores = None
|
260 |
+
return output if return_dict_in_generate else output.sequences
|
261 |
+
|
262 |
+
|
263 |
+
@dataclass
|
264 |
+
class DecodingCGCache:
|
265 |
+
max_batch_size: int = 0
|
266 |
+
max_seqlen: int = 0
|
267 |
+
device = None
|
268 |
+
dtype = None
|
269 |
+
callables: dict = field(default_factory=dict)
|
270 |
+
mempool = None
|
271 |
+
inference_params: Optional[InferenceParams] = None
|
272 |
+
run: Optional[Callable] = None
|
273 |
+
|
274 |
+
|
275 |
+
@torch.inference_mode()
|
276 |
+
def update_graph_cache(
|
277 |
+
model,
|
278 |
+
cache,
|
279 |
+
batch_size,
|
280 |
+
seqlen_og,
|
281 |
+
max_seqlen,
|
282 |
+
decoding_seqlens=(1,),
|
283 |
+
dtype=None,
|
284 |
+
n_warmups=2,
|
285 |
+
):
|
286 |
+
if cache is None:
|
287 |
+
cache = DecodingCGCache()
|
288 |
+
param_example = next(iter(model.parameters()))
|
289 |
+
device = param_example.device
|
290 |
+
if dtype is None:
|
291 |
+
dtype = param_example.dtype
|
292 |
+
if (
|
293 |
+
(device, dtype) != (cache.device, cache.dtype)
|
294 |
+
or batch_size > cache.max_batch_size
|
295 |
+
or max_seqlen > cache.max_seqlen
|
296 |
+
): # Invalidate the cache
|
297 |
+
cache.callables = {}
|
298 |
+
cache.mempool = None
|
299 |
+
cache.inference_params = None
|
300 |
+
gc.collect()
|
301 |
+
cache.device, cache.dtype = device, dtype
|
302 |
+
cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
|
303 |
+
assert hasattr(model, "allocate_inference_cache"), "CUDA graph decoding requires that the model has a method allocate_inference_cache"
|
304 |
+
inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
|
305 |
+
lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
|
306 |
+
cache.inference_params = InferenceParams(
|
307 |
+
max_seqlen=max_seqlen,
|
308 |
+
max_batch_size=batch_size,
|
309 |
+
seqlen_offset=seqlen_og,
|
310 |
+
key_value_memory_dict=inf_cache,
|
311 |
+
lengths_per_sample=lengths_per_sample,
|
312 |
+
)
|
313 |
+
cache.mempool = torch.cuda.graphs.graph_pool_handle()
|
314 |
+
for decoding_seqlen in decoding_seqlens:
|
315 |
+
if (batch_size, decoding_seqlen) not in cache.callables:
|
316 |
+
cache.callables[batch_size, decoding_seqlen] = capture_graph(
|
317 |
+
model,
|
318 |
+
cache.inference_params,
|
319 |
+
batch_size,
|
320 |
+
max_seqlen,
|
321 |
+
decoding_seqlen=decoding_seqlen,
|
322 |
+
mempool=cache.mempool,
|
323 |
+
n_warmups=n_warmups,
|
324 |
+
)
|
325 |
+
|
326 |
+
def dispatch(input_ids, position_ids, seqlen):
|
327 |
+
batch_size, decoding_seqlen = input_ids.shape[:2]
|
328 |
+
return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)
|
329 |
+
|
330 |
+
cache.run = dispatch
|
331 |
+
cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing
|
332 |
+
return cache
|
333 |
+
|
334 |
+
|
335 |
+
def capture_graph(
|
336 |
+
model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
|
337 |
+
):
|
338 |
+
device = next(iter(model.parameters())).device
|
339 |
+
input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
|
340 |
+
position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
|
341 |
+
seqlen_offset_og = inference_params.seqlen_offset
|
342 |
+
inference_params.seqlen_offset = max_seqlen - decoding_seqlen
|
343 |
+
inference_params.lengths_per_sample[:] = inference_params.seqlen_offset
|
344 |
+
|
345 |
+
# Warmup before capture
|
346 |
+
s = torch.cuda.Stream()
|
347 |
+
s.wait_stream(torch.cuda.current_stream())
|
348 |
+
with torch.cuda.stream(s):
|
349 |
+
for _ in range(n_warmups):
|
350 |
+
logits = model(
|
351 |
+
input_ids,
|
352 |
+
position_ids=position_ids,
|
353 |
+
inference_params=inference_params,
|
354 |
+
num_last_tokens=decoding_seqlen,
|
355 |
+
).logits
|
356 |
+
s.synchronize()
|
357 |
+
# This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
|
358 |
+
# which requires that graph launch and non-captured launch to not overlap (I think,
|
359 |
+
# that's how I interpret the documentation). I'm not sure if this is required.
|
360 |
+
if torch.distributed.is_initialized():
|
361 |
+
torch.distributed.barrier()
|
362 |
+
torch.cuda.current_stream().wait_stream(s)
|
363 |
+
# Captures the graph
|
364 |
+
# To allow capture, automatically sets a side stream as the current stream in the context
|
365 |
+
graph = torch.cuda.CUDAGraph()
|
366 |
+
with torch.cuda.graph(graph, pool=mempool):
|
367 |
+
logits = model(
|
368 |
+
input_ids,
|
369 |
+
position_ids=position_ids,
|
370 |
+
inference_params=inference_params,
|
371 |
+
num_last_tokens=decoding_seqlen,
|
372 |
+
).logits
|
373 |
+
|
374 |
+
def run(new_input_ids, new_position_ids, seqlen):
|
375 |
+
inference_params.lengths_per_sample[:] = seqlen
|
376 |
+
input_ids.copy_(new_input_ids)
|
377 |
+
position_ids.copy_(new_position_ids)
|
378 |
+
graph.replay()
|
379 |
+
return logits.clone()
|
380 |
+
|
381 |
+
inference_params.seqlen_offset = seqlen_offset_og
|
382 |
+
return run
|
protxlstm/models/__init__.py
ADDED
File without changes
|
protxlstm/models/llama.py
ADDED
@@ -0,0 +1,342 @@
1 |
+
import json
|
2 |
+
import math
|
3 |
+
import os
|
4 |
+
from collections import namedtuple
|
5 |
+
from typing import Optional, Tuple
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from transformers import PretrainedConfig
|
11 |
+
|
12 |
+
from protxlstm.xlstm.components.rotary_position import compute_freqs_cis
|
13 |
+
|
14 |
+
# Note: generation capabilities are not implemented for the transformer
|
15 |
+
|
16 |
+
class TransformerConfig(PretrainedConfig):
|
17 |
+
|
18 |
+
model_type = "llama"
|
19 |
+
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
d_model,
|
23 |
+
n_layer,
|
24 |
+
n_heads,
|
25 |
+
n_kv_heads,
|
26 |
+
bidirectional,
|
27 |
+
vocab_size,
|
28 |
+
hidden_dim,
|
29 |
+
multiple_of, # MLP hidden layer size will be multiple of
|
30 |
+
norm_eps,
|
31 |
+
max_length,
|
32 |
+
dropout,
|
33 |
+
max_position_embeddings,
|
34 |
+
rope_base_frequency,
|
35 |
+
**kwargs
|
36 |
+
):
|
37 |
+
super().__init__(**kwargs)
|
38 |
+
|
39 |
+
# store the model hyperparameters (Llama-style architecture)
|
40 |
+
self.dim = d_model
|
41 |
+
self.n_layers = n_layer
|
42 |
+
self.n_heads = n_heads
|
43 |
+
self.n_kv_heads = n_kv_heads
|
44 |
+
self.causal_attention = not bidirectional
|
45 |
+
self.vocab_size = vocab_size
|
46 |
+
self.hidden_dim = hidden_dim
|
47 |
+
self.multiple_of = multiple_of
|
48 |
+
self.norm_eps = norm_eps
|
49 |
+
self.max_seq_len = max_length
|
50 |
+
self.dropout = dropout
|
51 |
+
self.max_position_embeddings = max_position_embeddings
|
52 |
+
self.rope_base_frequency = rope_base_frequency
|
53 |
+
|
54 |
+
class RMSNorm_transformer(torch.nn.Module):
|
55 |
+
def __init__(self, dim: int, eps: float):
|
56 |
+
super().__init__()
|
57 |
+
self.eps = eps
|
58 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
59 |
+
|
60 |
+
def _norm(self, x):
|
61 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
62 |
+
|
63 |
+
def forward(self, x):
|
64 |
+
output = self._norm(x.float()).type_as(x)
|
65 |
+
return output * self.weight
|
66 |
+
|
67 |
+
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
|
68 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
69 |
+
t = torch.arange(end, device=freqs.device) # type: ignore
|
70 |
+
freqs = torch.outer(t, freqs).float() # type: ignore
|
71 |
+
freqs_cos = torch.cos(freqs) # real part
|
72 |
+
freqs_sin = torch.sin(freqs) # imaginary part
|
73 |
+
return freqs_cos, freqs_sin
|
74 |
+
|
75 |
+
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
76 |
+
ndim = x.ndim
|
77 |
+
assert 0 <= 1 < ndim
|
78 |
+
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
|
79 |
+
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
80 |
+
return freqs_cis.view(shape)
|
81 |
+
|
82 |
+
def apply_rotary_emb(
|
83 |
+
xq: torch.Tensor,
|
84 |
+
xk: torch.Tensor,
|
85 |
+
freqs_cos: torch.Tensor,
|
86 |
+
freqs_sin: torch.Tensor
|
87 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
88 |
+
|
89 |
+
# reshape xq and xk to match the complex representation
|
90 |
+
xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
|
91 |
+
xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
|
92 |
+
|
93 |
+
# reshape freqs_cos and freqs_sin for broadcasting
|
94 |
+
freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
|
95 |
+
freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
|
96 |
+
|
97 |
+
# apply rotation using real numbers
|
98 |
+
xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
|
99 |
+
xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
|
100 |
+
xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
|
101 |
+
xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
|
102 |
+
|
103 |
+
# flatten last two dimensions
|
104 |
+
xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
|
105 |
+
xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
|
106 |
+
|
107 |
+
return xq_out.type_as(xq), xk_out.type_as(xk)
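A small sketch of the two RoPE helpers above applied to dummy query/key tensors (illustration only, assuming both functions are in scope):

```python
# Sketch: rotating dummy queries/keys with the RoPE helpers defined above.
import torch

bs, seqlen, n_heads, head_dim = 2, 16, 4, 32
xq = torch.randn(bs, seqlen, n_heads, head_dim)
xk = torch.randn(bs, seqlen, n_heads, head_dim)

freqs_cos, freqs_sin = precompute_freqs_cis(head_dim, seqlen)   # each (seqlen, head_dim // 2)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
print(xq_rot.shape)                                             # torch.Size([2, 16, 4, 32])
```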
|
108 |
+
|
109 |
+
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
110 |
+
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
|
111 |
+
bs, slen, n_kv_heads, head_dim = x.shape
|
112 |
+
if n_rep == 1:
|
113 |
+
return x
|
114 |
+
return (
|
115 |
+
x[:, :, :, None, :]
|
116 |
+
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
|
117 |
+
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
|
118 |
+
)
|
119 |
+
|
120 |
+
class Attention(nn.Module):
|
121 |
+
def __init__(self, args: TransformerConfig):
|
122 |
+
super().__init__()
|
123 |
+
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
124 |
+
assert args.n_heads % self.n_kv_heads == 0
|
125 |
+
model_parallel_size = 1
|
126 |
+
self.n_local_heads = args.n_heads // model_parallel_size
|
127 |
+
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
|
128 |
+
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
129 |
+
self.head_dim = args.dim // args.n_heads
|
130 |
+
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
|
131 |
+
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
|
132 |
+
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
|
133 |
+
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
|
134 |
+
self.attn_dropout = nn.Dropout(args.dropout)
|
135 |
+
self.resid_dropout = nn.Dropout(args.dropout)
|
136 |
+
self.dropout = args.dropout
|
137 |
+
self.causal_attention = args.causal_attention
|
138 |
+
|
139 |
+
# use flash attention or a manual implementation?
|
140 |
+
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
|
141 |
+
if not self.flash and self.causal_attention:
|
142 |
+
print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
|
143 |
+
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
|
144 |
+
mask = torch.triu(mask, diagonal=1)
|
145 |
+
self.register_buffer("mask", mask)
|
146 |
+
|
147 |
+
def forward(
|
148 |
+
self,
|
149 |
+
x: torch.Tensor,
|
150 |
+
freqs_cos: torch.Tensor,
|
151 |
+
freqs_sin: torch.Tensor,
|
152 |
+
):
|
153 |
+
bsz, seqlen, _ = x.shape
|
154 |
+
|
155 |
+
# QKV
|
156 |
+
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
157 |
+
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
158 |
+
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
159 |
+
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
160 |
+
|
161 |
+
# RoPE relative positional embeddings
|
162 |
+
xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
|
163 |
+
|
164 |
+
# grouped multiquery attention: expand out keys and values
|
165 |
+
xk = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
|
166 |
+
xv = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
|
167 |
+
|
168 |
+
# make heads into a batch dimension
|
169 |
+
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
|
170 |
+
xk = xk.transpose(1, 2)
|
171 |
+
xv = xv.transpose(1, 2)
|
172 |
+
|
173 |
+
# flash implementation
|
174 |
+
if self.flash:
|
175 |
+
output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=self.causal_attention)
|
176 |
+
else:
|
177 |
+
# manual implementation
|
178 |
+
scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
|
179 |
+
if self.causal_attention:
|
180 |
+
scores = scores + self.mask[:, :, :seqlen, :seqlen] # (bs, n_local_heads, seqlen, cache_len + seqlen)
|
181 |
+
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
|
182 |
+
scores = self.attn_dropout(scores)
|
183 |
+
output = torch.matmul(scores, xv) # (bs, n_local_heads, seqlen, head_dim)
|
184 |
+
|
185 |
+
# restore time as batch dimension and concat heads
|
186 |
+
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
|
187 |
+
|
188 |
+
# final projection into the residual stream
|
189 |
+
output = self.wo(output)
|
190 |
+
output = self.resid_dropout(output)
|
191 |
+
return output
|
192 |
+
|
193 |
+
class FeedForward(nn.Module):
|
194 |
+
def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
|
195 |
+
super().__init__()
|
196 |
+
if hidden_dim is None:
|
197 |
+
hidden_dim = 4 * dim
|
198 |
+
hidden_dim = int(2 * hidden_dim / 3)
|
199 |
+
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
200 |
+
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
201 |
+
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
202 |
+
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
203 |
+
self.dropout = nn.Dropout(dropout)
|
204 |
+
|
205 |
+
def forward(self, x):
|
206 |
+
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
|
207 |
+
|
208 |
+
class TransformerBlock(nn.Module):
|
209 |
+
def __init__(self, layer_id: int, args: TransformerConfig):
|
210 |
+
super().__init__()
|
211 |
+
self.n_heads = args.n_heads
|
212 |
+
self.dim = args.dim
|
213 |
+
self.head_dim = args.dim // args.n_heads
|
214 |
+
self.attention = Attention(args)
|
215 |
+
self.feed_forward = FeedForward(
|
216 |
+
dim=args.dim,
|
217 |
+
hidden_dim=args.hidden_dim,
|
218 |
+
multiple_of=args.multiple_of,
|
219 |
+
dropout=args.dropout,
|
220 |
+
)
|
221 |
+
self.layer_id = layer_id
|
222 |
+
self.attention_norm = RMSNorm_transformer(args.dim, eps=args.norm_eps)
|
223 |
+
self.ffn_norm = RMSNorm_transformer(args.dim, eps=args.norm_eps)
|
224 |
+
|
225 |
+
def forward(self, x, freqs_cos, freqs_sin):
|
226 |
+
h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
|
227 |
+
out = h + self.feed_forward.forward(self.ffn_norm(h))
|
228 |
+
return out
|
229 |
+
|
230 |
+
class Transformer(nn.Module):
|
231 |
+
|
232 |
+
last_loss: Optional[torch.Tensor]
|
233 |
+
|
234 |
+
def __init__(self, params: TransformerConfig):
|
235 |
+
super().__init__()
|
236 |
+
self.params = params
|
237 |
+
self.vocab_size = params.vocab_size
|
238 |
+
self.n_layers = params.n_layers
|
239 |
+
|
240 |
+
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
|
241 |
+
self.dropout = nn.Dropout(params.dropout)
|
242 |
+
self.layers = torch.nn.ModuleList()
|
243 |
+
for layer_id in range(params.n_layers):
|
244 |
+
self.layers.append(TransformerBlock(layer_id, params))
|
245 |
+
self.layer_head_dim = self.layers[0].head_dim
|
246 |
+
|
247 |
+
self.norm = RMSNorm_transformer(params.dim, eps=params.norm_eps)
|
248 |
+
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
|
249 |
+
|
250 |
+
# share the unembedding parameters with the embedding parameters
|
251 |
+
self.tok_embeddings.weight = self.output.weight # https://paperswithcode.com/method/weight-tying
|
252 |
+
|
253 |
+
# some useful precompute for the RoPE relative positional embeddings
|
254 |
+
# freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
|
255 |
+
# self.register_buffer("freqs_cos", freqs_cos, persistent=False)
|
256 |
+
# self.register_buffer("freqs_sin", freqs_sin, persistent=False)
|
257 |
+
|
258 |
+
# init all weights
|
259 |
+
self.apply(self._init_weights)
|
260 |
+
# apply special scaled init to the residual projections, per GPT-2 paper
|
261 |
+
for pn, p in self.named_parameters():
|
262 |
+
if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
|
263 |
+
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * params.n_layers))
|
264 |
+
|
265 |
+
# Initialize attribute for the loss of the last forward call. This will be set if the forward is called with a targets tensor.
|
266 |
+
self.last_loss = None
|
267 |
+
|
268 |
+
def _init_weights(self, module):
|
269 |
+
if isinstance(module, nn.Linear):
|
270 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
271 |
+
if module.bias is not None:
|
272 |
+
torch.nn.init.zeros_(module.bias)
|
273 |
+
elif isinstance(module, nn.Embedding):
|
274 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
275 |
+
|
276 |
+
def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
|
277 |
+
_bsz, seqlen = tokens.shape
|
278 |
+
h = self.tok_embeddings(tokens)
|
279 |
+
h = self.dropout(h)
|
280 |
+
# freqs_cos = self.freqs_cos[:seqlen]
|
281 |
+
# freqs_sin = self.freqs_sin[:seqlen]
|
282 |
+
|
283 |
+
if 'position_ids' in kwargs:
|
284 |
+
freqs_cos, freqs_sin = compute_freqs_cis(kwargs.pop("position_ids"), self.layer_head_dim, theta=self.params.rope_base_frequency)
|
285 |
+
else:
|
286 |
+
raise ValueError('Llama model only implemented with RoPEs')
|
287 |
+
|
288 |
+
freqs_cos = freqs_cos.squeeze()
|
289 |
+
freqs_sin = freqs_sin.squeeze()
|
290 |
+
|
291 |
+
for layer in self.layers:
|
292 |
+
h = layer(h, freqs_cos, freqs_sin)
|
293 |
+
h = self.norm(h)
|
294 |
+
|
295 |
+
if targets is not None:
|
296 |
+
logits = self.output(h)
|
297 |
+
self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
|
298 |
+
else:
|
299 |
+
logits = self.output(h)
|
300 |
+
self.last_loss = None
|
301 |
+
|
302 |
+
return logits
|
303 |
+
|
304 |
+
class TransformerLMHeadModel(nn.Module):
|
305 |
+
|
306 |
+
def __init__(
|
307 |
+
self,
|
308 |
+
config: TransformerConfig,
|
309 |
+
) -> None:
|
310 |
+
|
311 |
+
super().__init__()
|
312 |
+
|
313 |
+
self.config = config
|
314 |
+
|
315 |
+
self.backbone = Transformer(config)
|
316 |
+
|
317 |
+
def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
|
318 |
+
"""
|
319 |
+
num_last_tokens: if > 0, only return the logits for the last n tokens
|
320 |
+
"""
|
321 |
+
|
322 |
+
lm_logits = self.backbone(input_ids, position_ids=position_ids)
|
323 |
+
|
324 |
+
CausalLMOutput = namedtuple("CausalLMOutput", ["loss", "logits"])
|
325 |
+
return CausalLMOutput(loss=None, logits=lm_logits)
|
326 |
+
|
327 |
+
def save_pretrained(self, save_directory):
|
328 |
+
"""
|
329 |
+
Save the model and its configuration file to a directory.
|
330 |
+
"""
|
331 |
+
|
332 |
+
# Ensure save_directory exists
|
333 |
+
os.makedirs(save_directory, exist_ok=True)
|
334 |
+
|
335 |
+
# Save the model's state_dict
|
336 |
+
model_path = os.path.join(save_directory, "pytorch_model.bin")
|
337 |
+
torch.save(self.state_dict(), model_path)
|
338 |
+
|
339 |
+
# Save the configuration of the model
|
340 |
+
config_path = os.path.join(save_directory, "config.json")
|
341 |
+
with open(config_path, "w") as f:
|
342 |
+
json.dump(self.config.to_dict(), f)
|
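To make the configuration above concrete, a sketch that instantiates a tiny Transformer baseline (the hyperparameter values are illustrative only, not those of any trained checkpoint):

```python
# Sketch: building a small TransformerLMHeadModel from TransformerConfig (toy sizes).
import torch

cfg = TransformerConfig(
    d_model=64, n_layer=2, n_heads=4, n_kv_heads=4, bidirectional=False,
    vocab_size=40, hidden_dim=128, multiple_of=16, norm_eps=1e-5,
    max_length=256, dropout=0.0, max_position_embeddings=256,
    rope_base_frequency=10000.0,
)
model = TransformerLMHeadModel(cfg)
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.2f}M parameters")
```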
protxlstm/models/mamba.py
ADDED
@@ -0,0 +1,833 @@
1 |
+
# Original code from ProtMamba under Apache License 2.0.
|
2 |
+
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
from collections import namedtuple
|
6 |
+
from dataclasses import dataclass, field
|
7 |
+
from functools import partial
|
8 |
+
|
9 |
+
from mamba_ssm.models.config_mamba import MambaConfig
|
10 |
+
from mamba_ssm.modules.mamba_simple import Block, Mamba
|
11 |
+
from mamba_ssm.models.mixer_seq_simple import MixerModel, _init_weights
|
12 |
+
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
|
13 |
+
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
from torch.utils.checkpoint import checkpoint
|
17 |
+
from transformers import PretrainedConfig
|
18 |
+
|
19 |
+
from protxlstm.generation import GenerationMixinSafe
|
20 |
+
|
21 |
+
@dataclass
|
22 |
+
class MambaConfig(PretrainedConfig):
|
23 |
+
d_model: int = 2560
|
24 |
+
n_layer: int = 64
|
25 |
+
vocab_size: int = 50277
|
26 |
+
ssm_cfg: dict = field(default_factory=dict)
|
27 |
+
rms_norm: bool = True
|
28 |
+
residual_in_fp32: bool = True
|
29 |
+
fused_add_norm: bool = True
|
30 |
+
pad_vocab_size_multiple: int = 8
|
31 |
+
max_position_embeddings: int = 2048
|
32 |
+
|
33 |
+
def create_block(
|
34 |
+
d_model,
|
35 |
+
ssm_cfg=None,
|
36 |
+
norm_epsilon=1e-5,
|
37 |
+
rms_norm=False,
|
38 |
+
residual_in_fp32=False,
|
39 |
+
fused_add_norm=False,
|
40 |
+
layer_idx=None,
|
41 |
+
device=None,
|
42 |
+
dtype=None,
|
43 |
+
checkpoint_mixer=False,
|
44 |
+
):
|
45 |
+
if ssm_cfg is None:
|
46 |
+
ssm_cfg = {}
|
47 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
48 |
+
mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
|
49 |
+
norm_cls = partial(
|
50 |
+
nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
|
51 |
+
)
|
52 |
+
block = Block(
|
53 |
+
d_model,
|
54 |
+
mixer_cls,
|
55 |
+
norm_cls=norm_cls,
|
56 |
+
fused_add_norm=fused_add_norm,
|
57 |
+
residual_in_fp32=residual_in_fp32,
|
58 |
+
)
|
59 |
+
block.layer_idx = layer_idx
|
60 |
+
if checkpoint_mixer:
|
61 |
+
block.mixer = CheckpointedModule(block.mixer)
|
62 |
+
return block
|
63 |
+
|
64 |
+
class CheckpointedModule(torch.nn.Module):
|
65 |
+
def __init__(self, layer):
|
66 |
+
super().__init__()
|
67 |
+
self.ckpt_layer = layer
|
68 |
+
|
69 |
+
def forward(self, x, *args, **kwargs):
|
70 |
+
return checkpoint(self.ckpt_layer, x, use_reentrant=False)
|
71 |
+
|
72 |
+
# def state_dict(self, **kwargs):
|
73 |
+
# # Get the state dict of the underlying layer
|
74 |
+
# layer_state_dict = self.ckpt_layer.state_dict(**kwargs)
|
75 |
+
# # Create a new state dict with the original keys
|
76 |
+
# state_dict = {k.replace('ckpt_layer.', ''): v for k, v in layer_state_dict.items()}
|
77 |
+
# return state_dict
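A quick sketch of what `CheckpointedModule` does (illustration only): the wrapped layer's activations are recomputed during the backward pass instead of being kept in memory.

```python
# Sketch: wrapping an arbitrary sub-module with activation checkpointing.
import torch
import torch.nn as nn

layer = nn.Linear(16, 16)
ckpt_layer = CheckpointedModule(layer)

x = torch.randn(2, 16, requires_grad=True)
loss = ckpt_layer(x).sum()
loss.backward()        # the forward of `layer` is re-run here to rebuild activations
```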
|
78 |
+
|
79 |
+
class MixerModelSafe(MixerModel):
|
80 |
+
"""
|
81 |
+
Overwrite the forward method to allow saving intermediate layers.
|
82 |
+
"""
|
83 |
+
|
84 |
+
def forward(self, input_ids, inference_params=None, save_layer=[]):
|
85 |
+
hidden_states = self.embedding(input_ids)
|
86 |
+
residual = None
|
87 |
+
if len(save_layer) > 0:
|
88 |
+
hidden_states_dict = {}
|
89 |
+
for i, layer in enumerate(self.layers):
|
90 |
+
hidden_states, residual = layer(
|
91 |
+
hidden_states, residual, inference_params=inference_params
|
92 |
+
)
|
93 |
+
if i + 1 in save_layer:
|
94 |
+
hidden_states_dict[i + 1] = (
|
95 |
+
hidden_states.detach().cpu().to(torch.float).numpy()
|
96 |
+
)
|
97 |
+
if len(save_layer) > 0:
|
98 |
+
return hidden_states_dict
|
99 |
+
|
100 |
+
if not self.fused_add_norm:
|
101 |
+
residual = (
|
102 |
+
(hidden_states + residual) if residual is not None else hidden_states
|
103 |
+
)
|
104 |
+
hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
|
105 |
+
else:
|
106 |
+
# Set prenorm=False here since we don't need the residual
|
107 |
+
fused_add_norm_fn = (
|
108 |
+
rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
|
109 |
+
)
|
110 |
+
hidden_states = fused_add_norm_fn(
|
111 |
+
hidden_states,
|
112 |
+
self.norm_f.weight,
|
113 |
+
self.norm_f.bias,
|
114 |
+
eps=self.norm_f.eps,
|
115 |
+
residual=residual,
|
116 |
+
prenorm=False,
|
117 |
+
residual_in_fp32=self.residual_in_fp32,
|
118 |
+
)
|
119 |
+
return hidden_states
|
120 |
+
|
121 |
+
class MixerModelWithPosids(nn.Module):
|
122 |
+
r"""Mixer model for Mamba but we add positional encodings to the input embeddings."""
|
123 |
+
|
124 |
+
def __init__(
|
125 |
+
self,
|
126 |
+
d_model: int,
|
127 |
+
n_layer: int,
|
128 |
+
vocab_size: int,
|
129 |
+
max_position_embeddings: int,
|
130 |
+
ssm_cfg=None,
|
131 |
+
norm_epsilon: float = 1e-5,
|
132 |
+
rms_norm: bool = False,
|
133 |
+
initializer_cfg=None,
|
134 |
+
fused_add_norm=False,
|
135 |
+
residual_in_fp32=False,
|
136 |
+
device=None,
|
137 |
+
dtype=None,
|
138 |
+
checkpoint_mixer=False,
|
139 |
+
) -> None:
|
140 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
141 |
+
super().__init__()
|
142 |
+
self.residual_in_fp32 = residual_in_fp32
|
143 |
+
|
144 |
+
self.embedding = nn.Embedding(vocab_size, d_model // 2, **factory_kwargs)
|
145 |
+
self.position_embedding = nn.Embedding(
|
146 |
+
max_position_embeddings, d_model - d_model // 2, **factory_kwargs
|
147 |
+
)
|
148 |
+
|
149 |
+
# We change the order of residual and layer norm:
|
150 |
+
# Instead of LN -> Attn / MLP -> Add, we do:
|
151 |
+
# Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
|
152 |
+
# the main branch (output of MLP / Mixer). The model definition is unchanged.
|
153 |
+
# This is for performance reasons: we can fuse add + layer_norm.
|
154 |
+
self.fused_add_norm = fused_add_norm
|
155 |
+
if self.fused_add_norm:
|
156 |
+
if layer_norm_fn is None or rms_norm_fn is None:
|
157 |
+
raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
|
158 |
+
|
159 |
+
self.layers = nn.ModuleList(
|
160 |
+
[
|
161 |
+
create_block(
|
162 |
+
d_model,
|
163 |
+
ssm_cfg=ssm_cfg,
|
164 |
+
norm_epsilon=norm_epsilon,
|
165 |
+
rms_norm=rms_norm,
|
166 |
+
residual_in_fp32=residual_in_fp32,
|
167 |
+
fused_add_norm=fused_add_norm,
|
168 |
+
layer_idx=i,
|
169 |
+
checkpoint_mixer=checkpoint_mixer,
|
170 |
+
**factory_kwargs,
|
171 |
+
)
|
172 |
+
for i in range(n_layer)
|
173 |
+
]
|
174 |
+
)
|
175 |
+
|
176 |
+
self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
|
177 |
+
d_model, eps=norm_epsilon, **factory_kwargs
|
178 |
+
)
|
179 |
+
|
180 |
+
self.apply(
|
181 |
+
partial(
|
182 |
+
_init_weights,
|
183 |
+
n_layer=n_layer,
|
184 |
+
**(initializer_cfg if initializer_cfg is not None else {}),
|
185 |
+
)
|
186 |
+
)
|
187 |
+
|
188 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
189 |
+
return {
|
190 |
+
i: layer.allocate_inference_cache(
|
191 |
+
batch_size, max_seqlen, dtype=dtype, **kwargs
|
192 |
+
)
|
193 |
+
for i, layer in enumerate(self.layers)
|
194 |
+
}
|
195 |
+
|
196 |
+
def forward(self, input_ids, position_ids, inference_params=None, save_layer=[]):
|
197 |
+
hidden_states = torch.cat(
|
198 |
+
[
|
199 |
+
self.embedding(input_ids),
|
200 |
+
self.position_embedding(position_ids),
|
201 |
+
],
|
202 |
+
-1,
|
203 |
+
)
|
204 |
+
residual = None
|
205 |
+
if len(save_layer) > 0:
|
206 |
+
hidden_states_dict = {}
|
207 |
+
for i, layer in enumerate(self.layers):
|
208 |
+
hidden_states, residual = layer(
|
209 |
+
hidden_states, residual, inference_params=inference_params
|
210 |
+
)
|
211 |
+
if i + 1 in save_layer:
|
212 |
+
hidden_states_dict[i + 1] = (
|
213 |
+
hidden_states.detach().cpu().to(torch.float).numpy()
|
214 |
+
)
|
215 |
+
if len(save_layer) > 0:
|
216 |
+
return hidden_states_dict
|
217 |
+
|
218 |
+
if not self.fused_add_norm:
|
219 |
+
residual = (
|
220 |
+
(hidden_states + residual) if residual is not None else hidden_states
|
221 |
+
)
|
222 |
+
hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
|
223 |
+
else:
|
224 |
+
fused_add_norm_fn = (
|
225 |
+
rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
|
226 |
+
)
|
227 |
+
hidden_states = fused_add_norm_fn(
|
228 |
+
hidden_states,
|
229 |
+
self.norm_f.weight,
|
230 |
+
self.norm_f.bias,
|
231 |
+
eps=self.norm_f.eps,
|
232 |
+
residual=residual,
|
233 |
+
prenorm=False,
|
234 |
+
residual_in_fp32=self.residual_in_fp32,
|
235 |
+
)
|
236 |
+
return hidden_states
|
237 |
+
|
238 |
+
class MixerModelWith2DPosids(nn.Module):
|
239 |
+
r"""Mixer model for Mamba but we add positional encodings to the input embeddings."""
|
240 |
+
|
241 |
+
def __init__(
|
242 |
+
self,
|
243 |
+
d_model: int,
|
244 |
+
n_layer: int,
|
245 |
+
vocab_size: int,
|
246 |
+
max_position_embeddings: int,
|
247 |
+
max_sequence_position_embeddings: int = 512,
|
248 |
+
ssm_cfg=None,
|
249 |
+
norm_epsilon: float = 1e-5,
|
250 |
+
rms_norm: bool = False,
|
251 |
+
initializer_cfg=None,
|
252 |
+
fused_add_norm=False,
|
253 |
+
residual_in_fp32=False,
|
254 |
+
device=None,
|
255 |
+
dtype=None,
|
256 |
+
checkpoint_mixer=False,
|
257 |
+
) -> None:
|
258 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
259 |
+
super().__init__()
|
260 |
+
self.residual_in_fp32 = residual_in_fp32
|
261 |
+
|
262 |
+
self.embedding = nn.Embedding(
|
263 |
+
vocab_size, d_model - 2 * d_model // 4, **factory_kwargs
|
264 |
+
)
|
265 |
+
self.position_embedding = nn.Embedding(
|
266 |
+
max_position_embeddings, d_model // 4, **factory_kwargs
|
267 |
+
)
|
268 |
+
self.seq_position_embedding = nn.Embedding(
|
269 |
+
max_sequence_position_embeddings, d_model // 4, **factory_kwargs
|
270 |
+
)
|
271 |
+
self.d_embeddings = d_model - 2 * d_model // 4
|
272 |
+
|
273 |
+
# We change the order of residual and layer norm:
|
274 |
+
# Instead of LN -> Attn / MLP -> Add, we do:
|
275 |
+
# Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
|
276 |
+
# the main branch (output of MLP / Mixer). The model definition is unchanged.
|
277 |
+
# This is for performance reasons: we can fuse add + layer_norm.
|
278 |
+
self.fused_add_norm = fused_add_norm
|
279 |
+
if self.fused_add_norm:
|
280 |
+
if layer_norm_fn is None or rms_norm_fn is None:
|
281 |
+
raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
|
282 |
+
|
283 |
+
self.layers = nn.ModuleList(
|
284 |
+
[
|
285 |
+
create_block(
|
286 |
+
d_model,
|
287 |
+
ssm_cfg=ssm_cfg,
|
288 |
+
norm_epsilon=norm_epsilon,
|
289 |
+
rms_norm=rms_norm,
|
290 |
+
residual_in_fp32=residual_in_fp32,
|
291 |
+
fused_add_norm=fused_add_norm,
|
292 |
+
layer_idx=i,
|
293 |
+
checkpoint_mixer=checkpoint_mixer,
|
294 |
+
**factory_kwargs,
|
295 |
+
)
|
296 |
+
for i in range(n_layer)
|
297 |
+
]
|
298 |
+
)
|
299 |
+
|
300 |
+
self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
|
301 |
+
d_model, eps=norm_epsilon, **factory_kwargs
|
302 |
+
)
|
303 |
+
|
304 |
+
self.apply(
|
305 |
+
partial(
|
306 |
+
_init_weights,
|
307 |
+
n_layer=n_layer,
|
308 |
+
**(initializer_cfg if initializer_cfg is not None else {}),
|
309 |
+
)
|
310 |
+
)
|
311 |
+
|
312 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
313 |
+
return {
|
314 |
+
i: layer.allocate_inference_cache(
|
315 |
+
batch_size, max_seqlen, dtype=dtype, **kwargs
|
316 |
+
)
|
317 |
+
for i, layer in enumerate(self.layers)
|
318 |
+
}
|
319 |
+
|
320 |
+
def forward(
|
321 |
+
self,
|
322 |
+
input_ids,
|
323 |
+
position_ids,
|
324 |
+
seq_position_ids,
|
325 |
+
inference_params=None,
|
326 |
+
save_layer=[],
|
327 |
+
):
|
328 |
+
hidden_states = torch.cat(
|
329 |
+
[
|
330 |
+
self.embedding(input_ids),
|
331 |
+
self.position_embedding(position_ids),
|
332 |
+
self.seq_position_embedding(seq_position_ids),
|
333 |
+
],
|
334 |
+
-1,
|
335 |
+
)
|
336 |
+
residual = None
|
337 |
+
if len(save_layer) > 0:
|
338 |
+
hidden_states_dict = {}
|
339 |
+
for i, layer in enumerate(self.layers):
|
340 |
+
hidden_states, residual = layer(
|
341 |
+
hidden_states, residual, inference_params=inference_params
|
342 |
+
)
|
343 |
+
if i + 1 in save_layer:
|
344 |
+
hidden_states_dict[i + 1] = (
|
345 |
+
hidden_states.detach().cpu().to(torch.float).numpy()
|
346 |
+
)
|
347 |
+
if len(save_layer) > 0:
|
348 |
+
return hidden_states_dict
|
349 |
+
|
350 |
+
if not self.fused_add_norm:
|
351 |
+
residual = (
|
352 |
+
(hidden_states + residual) if residual is not None else hidden_states
|
353 |
+
)
|
354 |
+
hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
|
355 |
        else:
            fused_add_norm_fn = (
                rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
            )
            hidden_states = fused_add_norm_fn(
                hidden_states,
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
            )
        return hidden_states


class MambaLMHeadModelSafe(nn.Module, GenerationMixinSafe):

    def __init__(
        self,
        config: MambaConfig,
        initializer_cfg=None,
        device=None,
        dtype=None,
        checkpoint_mixer=False,
    ) -> None:
        self.config = config
        d_model = config.d_model
        n_layer = config.n_layer
        vocab_size = config.vocab_size
        ssm_cfg = config.ssm_cfg
        rms_norm = config.rms_norm
        residual_in_fp32 = config.residual_in_fp32
        fused_add_norm = config.fused_add_norm
        pad_vocab_size_multiple = config.pad_vocab_size_multiple
        factory_kwargs = {"device": device, "dtype": dtype}
        if checkpoint_mixer:
            raise NotImplementedError(
                "Checkpointing is not yet supported for MambaLMHeadModelSafe"
            )

        super().__init__()
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (
                vocab_size % pad_vocab_size_multiple
            )
        self.backbone = MixerModelSafe(
            d_model=d_model,
            n_layer=n_layer,
            vocab_size=vocab_size,
            ssm_cfg=ssm_cfg,
            rms_norm=rms_norm,
            initializer_cfg=initializer_cfg,
            fused_add_norm=fused_add_norm,
            residual_in_fp32=residual_in_fp32,
            **factory_kwargs,
        )
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)

        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )
        self.tie_weights()

    def tie_weights(self):
        self.lm_head.weight = self.backbone.embedding.weight

    def clip_grad_norm_(self, max_norm, norm_type=2.0):
        r"""Clip the norm of the gradients for the model.
        Args:
            max_norm (float or int): The maximum norm of the gradients.
                The gradients are modified in-place.
            norm_type (float or int): The type of the used p-norm. Can be 'inf' for infinity norm.
        Returns:
            Total norm of the parameters (viewed as a single vector).
        """
        return torch.nn.utils.clip_grad_value_(self.parameters(), max_norm)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.backbone.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(
        self,
        input_ids,
        position_ids=None,
        inference_params=None,
        num_last_tokens=0,
        save_layer=[],
        *args,
        **kwargs,
    ):
        """
        "position_ids" is just to be compatible with Transformer generation. We don't use it.
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        return self.protected_forward(
            input_ids, position_ids, inference_params, num_last_tokens, save_layer
        )

    def protected_forward(
        self,
        input_ids,
        position_ids=None,
        inference_params=None,
        num_last_tokens=0,
        save_layer=[],
    ):
        hidden_states = self.backbone(
            input_ids, inference_params=inference_params, save_layer=save_layer
        )
        if len(save_layer) > 0:
            return hidden_states
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        lm_logits = self.lm_head(hidden_states)
        CausalLMOutput = namedtuple("CausalLMOutput", ["loss", "logits"])
        return CausalLMOutput(loss=None, logits=lm_logits)

    @classmethod
    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
        config_data = load_config_hf(pretrained_model_name)
        config = MambaConfig(**config_data)
        model = cls(config, device=device, dtype=dtype, **kwargs)
        model.load_state_dict(
            load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype),
            strict=False,
        )
        return model

    def save_pretrained(self, save_directory):
        """
        Minimal implementation of save_pretrained for MambaLMHeadModel.
        Save the model and its configuration file to a directory.
        """
        # Ensure save_directory exists
        os.makedirs(save_directory, exist_ok=True)

        # Save the model's state_dict
        model_path = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.state_dict(), model_path)

        # Save the configuration of the model
        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "w") as f:
            json.dump(self.config.__dict__, f)


class MambaLMHeadModelwithPosids(nn.Module, GenerationMixinSafe):

    def __init__(
        self,
        config: MambaConfig,
        initializer_cfg=None,
        device=None,
        dtype=None,
        checkpoint_mixer=False,
    ) -> None:
        self.config = config
        d_model = config.d_model
        n_layer = config.n_layer
        vocab_size = config.vocab_size
        max_position_embeddings = config.max_position_embeddings
        ssm_cfg = config.ssm_cfg
        rms_norm = config.rms_norm
        residual_in_fp32 = config.residual_in_fp32
        fused_add_norm = config.fused_add_norm
        pad_vocab_size_multiple = config.pad_vocab_size_multiple
        factory_kwargs = {"device": device, "dtype": dtype}

        super().__init__()
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (
                vocab_size % pad_vocab_size_multiple
            )
        self.backbone = MixerModelWithPosids(
            d_model=d_model,
            n_layer=n_layer,
            vocab_size=vocab_size,
            max_position_embeddings=max_position_embeddings,
            ssm_cfg=ssm_cfg,
            rms_norm=rms_norm,
            initializer_cfg=initializer_cfg,
            fused_add_norm=fused_add_norm,
            residual_in_fp32=residual_in_fp32,
            checkpoint_mixer=checkpoint_mixer,
            **factory_kwargs,
        )
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)

        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )
        self.tie_weights()

    def tie_weights(self):
        self.lm_head.weight = self.backbone.embedding.weight

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.backbone.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(
        self,
        input_ids,
        position_ids=None,
        inference_params=None,
        num_last_tokens=0,
        save_layer=[],
        *args,
        **kwargs,
    ):
        """
        "position_ids" is just to be compatible with Transformer generation. We don't use it.
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        return self.protected_forward(
            input_ids, position_ids, inference_params, num_last_tokens, save_layer
        )

    def protected_forward(
        self,
        input_ids,
        position_ids=None,
        inference_params=None,
        num_last_tokens=0,
        save_layer=[],
    ):
        hidden_states = self.backbone(
            input_ids,
            position_ids=position_ids,
            inference_params=inference_params,
            save_layer=save_layer,
        )
        if len(save_layer) > 0:
            return hidden_states
        hidden_states = hidden_states[:, :, : self.config.d_model // 2]
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        lm_logits = self.lm_head(hidden_states)
        CausalLMOutput = namedtuple("CausalLMOutput", ["loss", "logits"])
        return CausalLMOutput(loss=None, logits=lm_logits)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name,
        device=None,
        dtype=None,
        checkpoint_mixer=False,
        **kwargs,
    ):
        config_data = load_config_hf(pretrained_model_name)
        config = MambaConfig(**config_data)
        model = cls(
            config,
            device=device,
            dtype=dtype,
            checkpoint_mixer=checkpoint_mixer,
            **kwargs,
        )
        state_dict = load_state_dict_hf(
            pretrained_model_name, device=device, dtype=dtype
        )
        if state_dict.keys() != model.state_dict().keys():
            if checkpoint_mixer:
                for key in model.state_dict().keys():
                    if "ckpt_layer" in key:
                        state_dict[key] = state_dict.pop(key.replace("ckpt_layer.", ""))
                print(
                    "Using a model that was pretrained without gradient checkpointing and now want to use it. Changed the keys of the state_dict to match the model's keys."
                )
            else:
                for key in list(state_dict.keys()):
                    if "ckpt_layer" in key:
                        state_dict[key.replace("ckpt_layer.", "")] = state_dict.pop(key)
                print(
                    "Using a model that was pretrained with gradient checkpointing but now do not want to use it. Changed the keys of the state_dict to match the model's keys."
                )
            assert (
                state_dict.keys() == model.state_dict().keys()
            ), "The keys of the state_dict do not match the model's keys."
        model.load_state_dict(state_dict)
        return model

    def save_pretrained(self, save_directory):
        """
        Minimal implementation of save_pretrained for MambaLMHeadModel.
        Save the model and its configuration file to a directory.
        """
        # Ensure save_directory exists
        os.makedirs(save_directory, exist_ok=True)

        # Save the model's state_dict
        model_path = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.state_dict(), model_path)

        # Save the configuration of the model
        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "w") as f:
            json.dump(self.config.__dict__, f)


class MambaLMHeadModelwith2DPosids(nn.Module, GenerationMixinSafe):

    def __init__(
        self,
        config: MambaConfig,
        initializer_cfg=None,
        device=None,
        dtype=None,
        checkpoint_mixer=False,
    ) -> None:
        self.config = config
        d_model = config.d_model
        n_layer = config.n_layer
        vocab_size = config.vocab_size
        max_position_embeddings = config.max_position_embeddings
        ssm_cfg = config.ssm_cfg
        rms_norm = config.rms_norm
        residual_in_fp32 = config.residual_in_fp32
        fused_add_norm = config.fused_add_norm
        pad_vocab_size_multiple = config.pad_vocab_size_multiple
        factory_kwargs = {"device": device, "dtype": dtype}

        super().__init__()
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (
                vocab_size % pad_vocab_size_multiple
            )
        self.backbone = MixerModelWith2DPosids(
            d_model=d_model,
            n_layer=n_layer,
            vocab_size=vocab_size,
            max_position_embeddings=max_position_embeddings,
            ssm_cfg=ssm_cfg,
            rms_norm=rms_norm,
            initializer_cfg=initializer_cfg,
            fused_add_norm=fused_add_norm,
            residual_in_fp32=residual_in_fp32,
            checkpoint_mixer=checkpoint_mixer,
            **factory_kwargs,
        )
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)

        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )
        self.tie_weights()

    def tie_weights(self):
        self.lm_head.weight = self.backbone.embedding.weight

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.backbone.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(
        self,
        input_ids,
        position_ids=None,
        seq_position_ids=None,
        inference_params=None,
        num_last_tokens=0,
        save_layer=[],
        *args,
        **kwargs,
    ):
        """
        "position_ids" is just to be compatible with Transformer generation. We don't use it.
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        return self.protected_forward(
            input_ids,
            position_ids,
            seq_position_ids,
            inference_params,
            num_last_tokens,
            save_layer,
        )

    def protected_forward(
        self,
        input_ids,
        position_ids=None,
        seq_position_ids=None,
        inference_params=None,
        num_last_tokens=0,
        save_layer=[],
    ):
        hidden_states = self.backbone(
            input_ids,
            position_ids=position_ids,
            seq_position_ids=seq_position_ids,
            inference_params=inference_params,
            save_layer=save_layer,
        )
        if len(save_layer) > 0:
            return hidden_states
        hidden_states = hidden_states[:, :, : self.backbone.d_embeddings]
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        lm_logits = self.lm_head(hidden_states)
        CausalLMOutput = namedtuple("CausalLMOutput", ["loss", "logits"])
        return CausalLMOutput(loss=None, logits=lm_logits)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name,
        device=None,
        dtype=None,
        checkpoint_mixer=False,
        **kwargs,
    ):
        config_data = load_config_hf(pretrained_model_name)
        config = MambaConfig(**config_data)
        model = cls(
            config,
            device=device,
            dtype=dtype,
            checkpoint_mixer=checkpoint_mixer,
            **kwargs,
        )
        state_dict = load_state_dict_hf(
            pretrained_model_name, device=device, dtype=dtype
        )
        if state_dict.keys() != model.state_dict().keys():
            if checkpoint_mixer:
                for key in model.state_dict().keys():
                    if "ckpt_layer" in key:
                        state_dict[key] = state_dict.pop(key.replace("ckpt_layer.", ""))
                print(
                    "Using a model that was pretrained without gradient checkpointing and now want to use it. Changed the keys of the state_dict to match the model's keys."
                )
            else:
                for key in list(state_dict.keys()):
                    if "ckpt_layer" in key:
                        state_dict[key.replace("ckpt_layer.", "")] = state_dict.pop(key)
                print(
                    "Using a model that was pretrained with gradient checkpointing but now do not want to use it. Changed the keys of the state_dict to match the model's keys."
                )
            assert (
                state_dict.keys() == model.state_dict().keys()
            ), "The keys of the state_dict do not match the model's keys."
        model.load_state_dict(state_dict)
        return model

    def save_pretrained(self, save_directory):
        """
        Minimal implementation of save_pretrained for MambaLMHeadModel.
        Save the model and its configuration file to a directory.
        """
        # Ensure save_directory exists
        os.makedirs(save_directory, exist_ok=True)

        # Save the model's state_dict
        model_path = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.state_dict(), model_path)

        # Save the configuration of the model
        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "w") as f:
            json.dump(self.config.__dict__, f)
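For orientation, a minimal usage sketch of the Mamba head models added above (not part of the committed file). The checkpoint path and the dummy tokens are hypothetical; a real call would tokenize sequences with AA_TO_ID from protxlstm.utils and point at a directory containing the config.json and pytorch_model.bin written by save_pretrained:

    # usage sketch only; "path/to/mamba_checkpoint" is a hypothetical directory
    import torch
    from protxlstm.models.mamba import MambaLMHeadModelwithPosids

    model = MambaLMHeadModelwithPosids.from_pretrained(
        "path/to/mamba_checkpoint", device="cuda", dtype=torch.bfloat16
    )
    tokens = torch.randint(0, 33, (1, 128), device="cuda")      # dummy token ids
    pos_ids = torch.arange(128, device="cuda").unsqueeze(0)     # per-residue positions
    out = model(tokens, position_ids=pos_ids)
    print(out.logits.shape)  # (batch, length, padded vocab size)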
protxlstm/models/xlstm.py
ADDED
@@ -0,0 +1,180 @@
__all__ = [
    "xLSTMConfig",
    "xLSTMLMHeadModel",
]

import json
import os
from collections import namedtuple
from dataclasses import asdict

import torch
import torch.nn as nn
from dacite import Config as DaciteConfig, from_dict
from omegaconf import OmegaConf
from transformers import PretrainedConfig

from protxlstm.generation import GenerationMixinSafe
from protxlstm.utils import load_config_hf, load_state_dict_hf
from protxlstm.xlstm.xlstm_lm_model import xLSTMLMModel, xLSTMLMModelConfig


class xLSTMConfig(PretrainedConfig):

    def __init__(self):
        self.config_dataclass = xLSTMLMModelConfig()

    def init_from_dict(self, config: dict):
        config = OmegaConf.create(config)
        self.config_dataclass = from_dict(
            data_class=xLSTMLMModelConfig,
            data=OmegaConf.to_container(config),
            config=DaciteConfig(strict=True),
        )
        return self

    def to_dict(self):
        return asdict(self.config_dataclass)


class xLSTMLMHeadModel(nn.Module, GenerationMixinSafe):

    def __init__(self, config: xLSTMConfig) -> None:
        super().__init__()

        self.config = config
        self.backbone = xLSTMLMModel(self.config.config_dataclass)
        self.backbone.reset_parameters()

        self.setup()

    def setup(self):

        if 'LOCAL_RANK' in os.environ:
            current_device = int(os.environ['LOCAL_RANK'])
        else:
            if 'SLURM_LOCALID' in os.environ:
                current_device = int(os.environ['SLURM_LOCALID'])
            else:
                current_device = 0

        #torch.cuda.set_device(f'cuda:{current_device}')

        #self.backbone = self.backbone.to("cuda")

    def forward(
        self,
        input_ids,
        state=None,
        position_ids=None,
        seq_position_ids=None,
        inference_params=None,
        num_last_tokens=0,
        save_layer=[],
        **kwargs,
    ):

        if self.config.config_dataclass.mlstm_block.mlstm.return_last_state:
            lm_logits, state = self.backbone(input_ids, position_ids=position_ids, seq_position_ids=seq_position_ids, state=state)
            CausalLMOutput = namedtuple("CausalLMOutput", ["loss", "logits", "state"])
            return CausalLMOutput(loss=None, logits=lm_logits, state=state)
        else:
            lm_logits = self.backbone(input_ids, position_ids=position_ids, seq_position_ids=seq_position_ids, state=state)
            CausalLMOutput = namedtuple("CausalLMOutput", ["loss", "logits"])
            return CausalLMOutput(loss=None, logits=lm_logits)

    def step(
        self,
        input_ids,
        state=None,
        position_ids=None,
        seq_position_ids=None,
        inference_params=None,
        num_last_tokens=0,
        save_layer=[],
        **kwargs,
    ):

        lm_logits, state = self.backbone.step(
            input_ids, state=state, position_ids=position_ids, seq_position_ids=seq_position_ids
        )

        return lm_logits, state

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name,
        device=None,
        dtype=None,
        mlstm_backend=None,
        mlstm_chunksize=None,
        checkpoint_blocks=None,
        rope_base_frequency=None,
        mlstm_return_last_state=None,
    ):
        # Load the checkpoint config
        config_dict = load_config_hf(pretrained_model_name)

        # update rope base frequency
        if rope_base_frequency is not None and config_dict.get("rope_base_frequency", None) != rope_base_frequency:
            config_dict["rope_base_frequency"] = rope_base_frequency

        # update mlstm backend
        if mlstm_backend is not None and config_dict["mlstm_block"]["mlstm"].get("backend", None) != mlstm_backend:
            assert mlstm_backend in ["chunkwise", "chunkwise_variable", "parallel"], "invalid mlstm backend."
            config_dict["mlstm_block"]["mlstm"]["backend"] = mlstm_backend

        # update mlstm chunksize
        if mlstm_chunksize is not None and config_dict["mlstm_block"]["mlstm"].get("chunk_size", None) != mlstm_chunksize:
            config_dict["mlstm_block"]["mlstm"]["chunk_size"] = mlstm_chunksize

        # update activation checkpointing
        if checkpoint_blocks is not None:
            config_dict["checkpoint_blocks"] = checkpoint_blocks

        if mlstm_return_last_state is not None:
            config_dict["mlstm_block"]["mlstm"]["return_last_state"] = mlstm_return_last_state

        if "slstm_block" in config_dict:
            config_dict.pop("slstm_block")

        if "slstm_at" in config_dict:
            config_dict.pop("slstm_at")

        config = xLSTMConfig().init_from_dict(config_dict)

        model = cls(config)

        state_dict = load_state_dict_hf(
            pretrained_model_name, device=device, dtype=dtype
        )
        assert (
            state_dict.keys() == model.state_dict().keys()
        ), "The keys of the state_dict do not match the model's keys."

        model.load_state_dict(state_dict)

        return model

    def save_pretrained(self, save_directory):
        """
        Save the model and its configuration file to a directory.
        """

        # Ensure save_directory exists
        os.makedirs(save_directory, exist_ok=True)

        # Save the model's state_dict
        model_path = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.state_dict(), model_path)

        # Save the configuration of the model
        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "w") as f:
            json.dump(self.config.to_dict(), f)
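A minimal loading sketch for the class above (not part of the committed file). It assumes the protxlstm/checkpoints/small directory bundled in this commit holds an xLSTM checkpoint whose config.json matches xLSTMConfig; the dummy token tensor is hypothetical:

    # usage sketch only
    import torch
    from protxlstm.models.xlstm import xLSTMLMHeadModel

    model = xLSTMLMHeadModel.from_pretrained("protxlstm/checkpoints/small")
    tokens = torch.randint(0, 33, (1, 64))   # dummy token ids
    out = model(tokens)
    print(out.logits.shape)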
protxlstm/plot_utils.py
ADDED
@@ -0,0 +1,26 @@
cd = {  # use dependent on model-type!!
    "xLSTM": "#3073AD",
    "Transformers": "#4B9D7A",
    "Mamba": "#DF8953",
    "S4": "#D275AB",
    "Hyena": "#E86A61",
}

def setup_matplotlib():
    import matplotlib.pyplot as plt
    from tueplots import bundles, axes
    bundles.icml2022()
    plt.rcParams.update(bundles.icml2022())
    plt.rcParams.update(axes.lines(base_width=0.5))
    plt.rcParams["text.usetex"] = False
    plt.rcParams['font.family'] = "sans-serif"
    plt.rcParams['font.serif'] = 'Arial'
    plt.rcParams['legend.edgecolor'] = 'grey'
    plt.rcParams['legend.framealpha'] = 0.7
    plt.rcParams['lines.linewidth'] = 1.2
    plt.rcParams['axes.grid'] = True
    plt.rcParams['axes.grid.axis'] = 'both'
    plt.rcParams['grid.alpha'] = 0.2
    plt.rcParams['axes.grid'] = True
    plt.rcParams['axes.prop_cycle'] = plt.cycler(color=cd.values())
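A minimal plotting sketch for these helpers (not part of the committed file); the data points and axis labels are made up for illustration, and setup_matplotlib requires the tueplots package from the environment file:

    # usage sketch only
    import matplotlib.pyplot as plt
    from protxlstm.plot_utils import cd, setup_matplotlib

    setup_matplotlib()  # apply the shared ICML-style defaults and color cycle
    plt.plot([1, 2, 3], [0.9, 0.7, 0.6], label="xLSTM", color=cd["xLSTM"])
    plt.plot([1, 2, 3], [0.9, 0.8, 0.75], label="Mamba", color=cd["Mamba"])
    plt.xlabel("context length")
    plt.ylabel("loss")
    plt.legend()
    plt.savefig("comparison.png")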
protxlstm/train.py
ADDED
@@ -0,0 +1,338 @@
# Original code from ProtMamba under Apache License 2.0.
#
# Modifications made by Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
#  - Extended to training of xlstm and transformer-based models
#  - Predefined splits instead of on-the-fly creation
#  - Option to overwrite config parameters from the command line
#  - wandb logging

import argparse
import os

import torch
from omegaconf import OmegaConf
from transformers import TrainingArguments

from protxlstm.dataloaders import ProteinMemmapDataset, ProteinDataCollator
from protxlstm.models.xlstm import xLSTMConfig, xLSTMLMHeadModel
from protxlstm.models.llama import TransformerConfig, TransformerLMHeadModel
from protxlstm.trainer import ProtTrainer, EarlyStoppingCallback, get_last_checkpoint
from protxlstm.utils import (
    AA_TO_ID,
    compute_metrics,
    is_zero_rank,
    parse_override_args,
    print_number_of_parameters,
    print_zero_rank,
    set_optimizer_and_scheduler,
    setup_wandb,
    load_model,
)

def run(config):
    """
    Run training loop.

    Args:
        config (dict): dictionary with the configuration parameters.
    """

    if config.model_type == 'llama':
        pe_kwargs = {
            'max_position_embeddings' : config["model"]["max_position_embeddings"],
            'add_position_ids' : '1d',
        }
    elif config.model_type == 'mamba':
        from protxlstm.models.mamba import MambaConfig, MambaLMHeadModelSafe, MambaLMHeadModelwithPosids, MambaLMHeadModelwith2DPosids
        pe_kwargs = {
            'max_position_embeddings' : config["model"]["max_position_embeddings"],
            'max_seq_position_embeddings' : config["model"]["max_seq_position_embeddings"],
            'add_position_ids' : config["model"]["add_position_ids"]
        }
    else:
        position_embeddings = config["model"]["position_embeddings"]
        assert position_embeddings in ["none", "abs_1d", "abs_2d", "rot_1d", "rot_2d"]
        if position_embeddings != "none":
            position_embeddings = position_embeddings.split("_")[-1]
        pe_kwargs = {
            'max_position_embeddings' : config["model"]["max_position_embeddings"],
            'max_seq_position_embeddings' : config["model"]["max_seq_position_embeddings"],
            'add_position_ids' : position_embeddings
        }

    # Setup WandB
    wandb_run_name = setup_wandb(config)

    # Load datasets
    dataset_params = {
        "msa_memmap_path": config["msa_memmap_path"],
        "msa_memmap_meta_path": config["msa_memmap_meta_path"],
        "sample": config["sample_sequences"],
        "max_msa_len": config["max_msa_len"],
        "reverse": False,
        "seed": config["seed_sequence_sampling"],
        "troubleshoot": False,
        "fim_strategy": config["fim_strategy"],
        "always_mask": config["always_mask"],
        **pe_kwargs,
    }
    train_dataset = ProteinMemmapDataset(subset_path=config["train_set"], **dataset_params)
    valid_dataset = ProteinMemmapDataset(subset_path=config["valid_set"], **dataset_params)
    train_eval_dataset = ProteinMemmapDataset(subset_path=config["train_eval_set"], **dataset_params)

    print(f'Train set size: {len(train_dataset)} Train eval set size: {len(train_eval_dataset)} Valid set size: {len(valid_dataset)}')

    assert (
        len(AA_TO_ID) == config["model"]["vocab_size"]
    ), f"Vocab size in the config file does not match the one in the code. It should be {len(AA_TO_ID)}"

    # Create data collator for batched training
    data_collator = ProteinDataCollator(max_sequence_length=config["max_msa_len"])

    # Check datatypes
    if config["dtype"] == "float32":
        dtype = torch.float32
    elif config["dtype"] == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise ValueError("dtype must be either float32 or bfloat16")

    # Initialize model
    if config.model_type == 'xlstm':

        # Load model for finetuning
        if config.finetune_model_path:
            # These fields are updated in the config loaded from the checkpoint
            config_update_kwargs = {
                "mlstm_backend": config["model"]["mlstm_block"]["mlstm"]["backend"],
                "mlstm_chunksize": config["model"]["mlstm_block"]["mlstm"]["chunk_size"],
                "checkpoint_blocks": config["model"]["checkpoint_blocks"],
                "rope_base_frequency": config["model"]["rope_base_frequency"]
            }
            model = load_model(
                config.finetune_model_path,
                model_class=xLSTMLMHeadModel,
                device="cuda",
                dtype=dtype,
                **config_update_kwargs
            )
        else:
            # Create new model
            xlstm_config = xLSTMConfig().init_from_dict(config["model"])
            model = xLSTMLMHeadModel(xlstm_config)

    elif config.model_type == 'mamba':

        _mamba_model = {
            "none": MambaLMHeadModelSafe,
            "1d": MambaLMHeadModelwithPosids,
            "2d": MambaLMHeadModelwith2DPosids,
        }
        Mamba = _mamba_model[config['model']["add_position_ids"]]

        # Load model for finetuning
        if config.finetune_model_path:
            model = load_model(
                config.finetune_model_path,
                model_class=Mamba,
                device="cuda",
                dtype=dtype,
                checkpoint_mixer=config["checkpoint_mixer"],
            )
        else:
            # Create new model
            mamba_config = MambaConfig(d_model=config['model']["d_model"],
                                       n_layer=config['model']["n_layer"],
                                       vocab_size=config['model']["vocab_size"],
                                       residual_in_fp32=config['model']["residual_in_fp32"])
            model = Mamba(mamba_config, dtype=dtype, checkpoint_mixer=config['model']["checkpoint_mixer"])

    elif config.model_type == 'llama':

        llama_config = TransformerConfig(
            d_model=config["model"]["d_model"],
            n_layer=config["model"]["n_layer"],
            n_heads=config["model"]["n_heads"],
            n_kv_heads=config["model"]["n_kv_heads"],
            bidirectional=config["model"]["bidirectional"],
            hidden_dim=config["model"]["hidden_dim"],
            multiple_of=config["model"]["multiple_of"],
            norm_eps=config["model"]["norm_eps"],
            max_length=config["model"]["max_length"],
            vocab_size=config["model"]["vocab_size"],
            dropout=config["model"]["dropout"],
            max_position_embeddings=config["model"]["max_position_embeddings"],
            rope_base_frequency=config["model"]["rope_base_frequency"],
        )

        model = TransformerLMHeadModel(llama_config)

    else:
        raise ValueError(f"Unsupported model_type: {config.model_type}. Expected 'xlstm', 'mamba', or 'llama'.")


    # TODO: Improve what we want to print
    if is_zero_rank():
        print_number_of_parameters(model)
    print_zero_rank(f"dtype: {config['dtype']}")
    print_zero_rank(f"Epochs: {config['num_epochs']}")
    print_zero_rank(f"Batch size per GPU: {config['batch_size']}")
    print_zero_rank(f"Gradient accumulation steps: {config['gradient_accumulation_steps']}")
    eff_batch_size = config["batch_size"] * config["gradient_accumulation_steps"]
    nr_gpus = torch.cuda.device_count()
    print_zero_rank(f"GPUS: {nr_gpus}")
    eff_batch_size *= nr_gpus
    print_zero_rank(f"Effective batch size: {eff_batch_size}")
    print_zero_rank(
        f"Steps per training epoch: {len(train_dataset) // config['batch_size']}, eff. steps: {len(train_dataset) // eff_batch_size}"
    )
    print_zero_rank(f"Steps per evaluation epoch: {len(valid_dataset) // config['batch_size']}")
    print_zero_rank(f"Max MSA length: {config['max_msa_len']}")
    ev_epochs = round(
        config["eval_steps"] * config["batch_size"] / len(train_dataset), 3
    )
    print_zero_rank(
        f"Evaluation every {config['eval_steps']} steps, i.e. {ev_epochs} epochs. Effectively every {config['eval_steps']*config['gradient_accumulation_steps']} steps, i.e. {ev_epochs*config['gradient_accumulation_steps']} epochs."
    )
    if config.model_type == 'xlstm' and config["model"]["checkpoint_blocks"]:
        print_zero_rank("Using gradient checkpointing")
    if config["compute_only_fim_loss"]:
        print_zero_rank("Computing only FIM loss for training")

    # Training callbacks
    es_callback = EarlyStoppingCallback(
        train_path=config["output_dir"] + '/' + wandb_run_name, config=config
    )
    callbacks = [es_callback]

    # Optimizer and Schedulers
    optimizer, scheduler = set_optimizer_and_scheduler(
        config,
        len(train_dataset),
        model.parameters()
    )

    # Find checkpoint if available
    last_checkpoint = None
    if config.finetune_model_path is None:
        path = os.path.join(config["output_dir"], wandb_run_name)
        if os.path.exists(path):
            last_checkpoint = get_last_checkpoint(path)
            if last_checkpoint is None:
                print_zero_rank("No checkpoint found, starting training from scratch.")
            else:
                print_zero_rank(f"Resuming training from the last checkpoint: {last_checkpoint}")

    # Create trainer
    trainer = ProtTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset={"valid": valid_dataset, "train": train_eval_dataset},
        optimizers=(optimizer, scheduler),
        args=TrainingArguments(
            run_name=wandb_run_name,
            local_rank=int(os.getenv('LOCAL_RANK', '0')),
            learning_rate=config["learning_rate"],
            num_train_epochs=config["num_epochs"],
            per_device_train_batch_size=config["batch_size"],
            per_device_eval_batch_size=config["batch_size"],
            gradient_accumulation_steps=config["gradient_accumulation_steps"],
            eval_accumulation_steps=config["eval_accumulation_steps"],
            eval_strategy="steps",
            max_grad_norm=config["max_grad_norm"],
            bf16=config["dtype"] == "bfloat16",
            dataloader_num_workers=32,
            logging_steps=config["logging_steps"],
            eval_steps=config["eval_steps"],
            save_steps=config["save_steps"],
            output_dir=config["output_dir"] + '/' + wandb_run_name,
            logging_dir=config["output_dir"] + '/' + wandb_run_name,
            report_to="wandb" if is_zero_rank() else None,
            log_on_each_node=False,
            overwrite_output_dir=False,
            push_to_hub=False,
            label_names=["labels"],
        ),
        compute_only_fim_loss=config["compute_only_fim_loss"],
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=callbacks,
    )

    # Train model
    while True:
        if last_checkpoint is None and trainer.state.global_step == 0:
            eval_results = trainer.evaluate()
            print_zero_rank(
                f">>> Initial validation perplexity: {eval_results['eval_valid_perplexity/batch']:.2f}"
            )
        else:
            print_zero_rank(f"Resuming training from the last checkpoint: {last_checkpoint}")
        # Train
        trainer.train(resume_from_checkpoint=last_checkpoint)

        # Break training when the number of epochs is reached
        if (
            not es_callback.should_restart
            or trainer.state.epoch >= config["num_epochs"]
        ):
            eval_results = trainer.evaluate()
            print_zero_rank(
                f">>> Final Perplexity: {eval_results['eval_valid_perplexity/batch']:.2f}"
            )
            break
        # If the training was interrupted because of a loss spike, restart from the last checkpoint
        last_checkpoint = es_callback.checkpoint_path

    return trainer

if __name__ == "__main__":

    # Default configuration file paths
    default_model_config = "configs/xlstm_default_config.yaml"
    default_train_config = "configs/train_default_config.yaml"

    parser = argparse.ArgumentParser(
        description="Train or finetune a model with the provided configuration."
    )
    parser.add_argument(
        "--model_config_path",
        type=str,
        default=default_model_config,
        help=f"Path to the model configuration file (default: {default_model_config})"
    )
    parser.add_argument(
        "--train_config_path",
        type=str,
        default=default_train_config,
        help=f"Path to the training and dataset configuration file (default: {default_train_config})"
    )
    parser.add_argument(
        "overrides",
        nargs=argparse.REMAINDER,
        help="Override configuration values using key=value format.",
    )

    args = parser.parse_args()

    # Check if the default config files exist, or raise an error
    if not os.path.exists(args.model_config_path):
        raise FileNotFoundError(f"Model config file not found: {args.model_config_path}")
    if not os.path.exists(args.train_config_path):
        raise FileNotFoundError(f"Train config file not found: {args.train_config_path}")

    # Load the model and training configurations
    model_config = OmegaConf.load(args.model_config_path)
    train_config = OmegaConf.load(args.train_config_path)

    # Merge the model and training configurations
    config = OmegaConf.merge(model_config, train_config)

    # Parse overrides
    if args.overrides:
        overrides = parse_override_args(args.overrides)
        config.merge_with(OmegaConf.create(overrides))

    # Run the training/finetuning process
    run(config)
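A minimal sketch (not part of the committed file) of driving run() programmatically, mirroring the __main__ block above; the two YAML paths are the defaults the script points at, and the override keys are illustrative:

    # usage sketch only
    from omegaconf import OmegaConf
    from protxlstm.train import run

    model_config = OmegaConf.load("configs/xlstm_default_config.yaml")
    train_config = OmegaConf.load("configs/train_default_config.yaml")
    config = OmegaConf.merge(model_config, train_config)

    # equivalent of key=value overrides on the command line
    config.merge_with(OmegaConf.create({"learning_rate": 1e-4, "batch_size": 2}))

    trainer = run(config)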
protxlstm/trainer.py
ADDED
@@ -0,0 +1,123 @@
# Original code from ProtMamba under Apache License 2.0.
#
# Modifications made by Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
#  - MambaTrainer renamed to ProtTrainer

import os
import re

import torch
from transformers import Trainer, TrainerCallback

from protxlstm.utils import AA_TO_ID, find_fim_indices

class ProtTrainer(Trainer):
    """
    Base HuggingFace Trainer used for training.

    from https://github.com/havenhq/mamba-chat/blob/main/trainer/mamba_trainer.py"""
    def __init__(self, compute_only_fim_loss, **kwargs,):
        super().__init__(**kwargs)
        self.compute_only_fim_loss = compute_only_fim_loss

    def compute_loss(self, model, inputs, return_outputs=False):
        input_ids = inputs.pop("input_ids")
        labels = inputs.pop("labels")
        if "seq_position_ids" in inputs and "position_ids" in inputs:
            position_ids = inputs.pop("position_ids")
            seq_position_ids = inputs.pop("seq_position_ids")
            output = model(input_ids, position_ids=position_ids, seq_position_ids=seq_position_ids)
        elif "position_ids" in inputs:
            position_ids = inputs.pop("position_ids")
            output = model(input_ids, position_ids=position_ids)
        else:
            output = model(input_ids)
        lm_logits = output.logits

        labels = labels.to(lm_logits.device)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss()
        if self.compute_only_fim_loss:
            # start and end tokens
            is_cls_tokens = (labels == AA_TO_ID["<cls>"])
            is_eos_tokens = (labels == AA_TO_ID["<eos>"])
            bool_fim = find_fim_indices(is_cls_tokens, is_eos_tokens)
            # include also the cls token
            bool_fim = bool_fim | is_cls_tokens
            inds = torch.where(bool_fim)
            lm_loss = loss_fct(shift_logits[inds[0], inds[1], :], labels[bool_fim])
        else:
            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))

        return (lm_loss, output) if return_outputs else lm_loss

    def save_model(self, output_dir, _internal_call):
        if int(os.getenv('LOCAL_RANK', '0')) == 0:
            self.model.save_pretrained(output_dir)

PREFIX_CHECKPOINT_DIR = "checkpoint"
_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$")

def get_last_checkpoint(folder, max_steps=None):
    content = os.listdir(folder)
    checkpoints = [
        path
        for path in content
        if _re_checkpoint.search(path) is not None and os.path.isdir(os.path.join(folder, path))
    ]
    if len(checkpoints) == 0:
        return

    max_steps = max_steps if max_steps is not None else float("inf")
    # func = lambda x: int(_re_checkpoint.search(x).groups()[0])
    def func(x):
        num = int(_re_checkpoint.search(x).groups()[0])
        return num if num < max_steps else -1
    return os.path.join(folder, max(checkpoints, key=func))

class EarlyStoppingCallback(TrainerCallback):
    def __init__(self, train_path, config=None):
        self.step_counter_reset = 0
        self.step_counter_stop = 0
        self.best_loss = None
        self.train_path = train_path
        self.patience = config["patience"]
        self.metric_name = config["early_stopping_metric"]
        self.checkpoint_path = None
        self.should_restart = False
        self.eval_steps = config["eval_steps"]
        self.loss_increase_factor = config["loss_increase_factor"]

    def get_checkpoint_path(self, max_steps):
        last_checkpoint = None
        if os.path.exists(self.train_path):
            last_checkpoint = get_last_checkpoint(self.train_path, max_steps)
            if last_checkpoint is None:
                print("No checkpoint found, starting training from scratch.")
            else:
                print(f"Max checkpoint allowed: {max_steps}, restarting from {last_checkpoint}.")
        return last_checkpoint

    def on_evaluate(self, args, state, control, model, metrics, **kwargs):
        if self.metric_name in metrics:
            if self.best_loss is None:
                self.best_loss = metrics[self.metric_name]
            elif self.best_loss*self.loss_increase_factor < metrics[self.metric_name]:
                self.step_counter += 1
                if self.step_counter >= self.patience:
                    checkpoint_path = self.get_checkpoint_path(max_steps=(state.global_step-self.patience*self.eval_steps))
                    control.should_training_stop = True
                    self.checkpoint_path = checkpoint_path
                    self.should_restart = True
            else:
                self.step_counter = 0
                self.best_loss = min(self.best_loss, metrics[self.metric_name])
                self.should_restart = False

    def on_train_begin(self, args, state, control, **kwargs):
        self.step_counter = 0
        self.best_loss = None
        self.should_restart = False
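A minimal sketch (not part of the committed file) of the checkpoint-resumption helper above; the run directory is hypothetical:

    # usage sketch only
    from protxlstm.trainer import get_last_checkpoint

    run_dir = "outputs/my_run"                          # hypothetical training output dir
    last = get_last_checkpoint(run_dir)                 # newest checkpoint-<step> subdir, or None
    rollback = get_last_checkpoint(run_dir, max_steps=5000)  # newest checkpoint below step 5000
    print(last, rollback)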
protxlstm/utils.py
ADDED
@@ -0,0 +1,482 @@
1 |
+
# Some of the objects in this file come from ProtMamba and mamba both under Apache License 2.0.
|
2 |
+
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import rich
|
8 |
+
import torch
|
9 |
+
from Bio import SeqIO
|
10 |
+
from omegaconf import DictConfig, OmegaConf
|
11 |
+
from torch.optim import AdamW
|
12 |
+
import wandb
|
13 |
+
from transformers import (
|
14 |
+
get_constant_schedule_with_warmup,
|
15 |
+
get_cosine_schedule_with_warmup,
|
16 |
+
get_cosine_with_hard_restarts_schedule_with_warmup,
|
17 |
+
)
|
18 |
+
from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
|
19 |
+
from transformers.utils.hub import cached_file
|
20 |
+
|
21 |
+
__all__ = ['AA_TO_ID', 'MASK_TO_ID', 'ID_TO_AA', 'load_model', 'encode_sequence', 'decode_sequence', 'clean_sequence', 'tokenizer',
|
22 |
+
'reorder_masked_sequence', 'load_sequences_from_msa_file', 'prepare_dataset_for_fim_generation',
|
23 |
+
'prepare_tokens', 'prepare_target', 'print_number_of_parameters', 'find_fim_indices',
|
24 |
+
'compute_metrics', 'compute_metrics_with_std', 'print_config', 'print_zero_rank', 'is_zero_rank']
|
25 |
+
|
26 |
+
# Constants
|
27 |
+
AA_TO_ID = {'<cls>': 0,
|
28 |
+
'<pad>': 1,
|
29 |
+
'<eos>': 2,
|
30 |
+
'<unk>': 3,
|
31 |
+
'L': 4,
|
32 |
+
'A': 5,
|
33 |
+
'G': 6,
|
34 |
+
'V': 7,
|
35 |
+
'S': 8,
|
36 |
+
'E': 9,
|
37 |
+
'R': 10,
|
38 |
+
'T': 11,
|
39 |
+
'I': 12,
|
40 |
+
'D': 13,
|
41 |
+
'P': 14,
|
42 |
+
'K': 15,
|
43 |
+
'Q': 16,
|
44 |
+
'N': 17,
|
45 |
+
'F': 18,
|
46 |
+
'Y': 19,
|
47 |
+
'M': 20,
|
48 |
+
'H': 21,
|
49 |
+
'W': 22,
|
50 |
+
'C': 23,
|
51 |
+
'X': 24,
|
52 |
+
'B': 25,
|
53 |
+
'U': 26,
|
54 |
+
'Z': 27,
|
55 |
+
'O': 28,
|
56 |
+
'.': 29,
|
57 |
+
'-': 30,
|
58 |
+
'<null_1>': 31,
|
59 |
+
'<mask>': 32}
|
60 |
+
|
61 |
+
MASK_TO_ID = {"<mask-1>": 33,
|
62 |
+
"<mask-2>": 34,
|
63 |
+
"<mask-3>": 35,
|
64 |
+
"<mask-4>": 36,
|
65 |
+
"<mask-5>": 37,}
|
66 |
+
|
67 |
+
AA_TO_ID.update(MASK_TO_ID)
|
68 |
+
|
69 |
+
ID_TO_AA = {v: k for k, v in AA_TO_ID.items()}
|
70 |
+
|
71 |
+
# Logging & prints
|
72 |
+
def setup_wandb(config):
|
73 |
+
|
74 |
+
# WandB setup
|
75 |
+
os.environ["WANDB_PROJECT"] = config["wandb_project"]
|
76 |
+
os.environ["WANDB_ENTITY"] = config["wandb_entity"]
|
77 |
+
os.environ["WANDB_MODE"] = config["wandb_mode"]
|
78 |
+
|
79 |
+
if config['model_type'] == 'xlstm':
|
80 |
+
pe = config['model']['add_position_ids']
|
81 |
+
pe = 'None' if pe == 'none' else 'AbsPE' if pe == 'abs_1d' else 'AbsPE2' if pe == 'abs_2d' else 'RoPE' if pe == 'rot_1d' else pe == 'rot_2d'
|
82 |
+
wandb_run_name = f"{config['model_type']}_l{config['model']['num_blocks']}_d{config['model']['embedding_dim']}_{pe}_s{config['max_msa_len']}_lr{config['learning_rate']}"
|
83 |
+
elif config['model_type'] == 'mamba':
|
84 |
+
pe = config['model']['add_position_ids']
|
85 |
+
pe = 'None' if pe == 'none' else 'AbsPE' if pe == '1d' else pe == '2d'
|
86 |
+
wandb_run_name = f"{config['model_type']}_l{config['model']['n_layer']}_d{config['model']['d_model']}_{pe}_s{config['max_msa_len']}_lr{config['learning_rate']}"
|
87 |
+
elif config['model_type'] == 'llama':
|
88 |
+
pe = 'RoPE'
|
89 |
+
wandb_run_name = f"{config['model_type']}_l{config['model']['n_layer']}_d{config['model']['d_model']}_dh{config['model']['hidden_dim']}_{prepare_dataset_for_fim_generation}_s{config['max_msa_len']}_lr{config['learning_rate']}_sched-{config['scheduler']}"
|
90 |
+
|
91 |
+
if config['name_prefix']:
|
92 |
+
wandb_run_name = str(config['name_prefix']) + '_' + wandb_run_name
|
93 |
+
if config['name_suffix']:
|
94 |
+
wandb_run_name = wandb_run_name + '_' + str(config['name_suffix'])
|
95 |
+
|
96 |
+
if is_zero_rank():
|
97 |
+
wandb.init(
|
98 |
+
project=config["wandb_project"],
|
99 |
+
entity=config["wandb_entity"],
|
100 |
+
mode=config["wandb_mode"],
|
101 |
+
name=wandb_run_name)
|
102 |
+
config_dict = OmegaConf.to_container(config, resolve=True)
|
103 |
+
wandb.config.update(config_dict)
|
104 |
+
return wandb_run_name
|
105 |
+
|
106 |
+
def is_zero_rank():
|
107 |
+
return int(os.getenv('LOCAL_RANK', '0')) == 0
|
108 |
+
|
109 |
+
def print_zero_rank(var):
|
110 |
+
if is_zero_rank():
|
111 |
+
print(var)
|
112 |
+
|
113 |
+
def print_number_of_parameters(model):
|
114 |
+
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
115 |
+
formatted_num_params = f"{num_params:_}"
|
116 |
+
print("Number of trainable parameters: ", formatted_num_params)
|
117 |
+
|
118 |
+
# Sequence tools
|
119 |
+
def encode_sequence(sequence):
|
120 |
+
"""Tokenize a sequence of amino acids and add a cls token at the beginning."""
|
121 |
+
tokenized_sequence = [AA_TO_ID[aa] if aa in AA_TO_ID else AA_TO_ID['<unk>'] for aa in sequence]
|
122 |
+
return [AA_TO_ID['<cls>']] + tokenized_sequence
|
123 |
+
|
124 |
+
def decode_sequence(sequence):
|
125 |
+
"""Decode a sequence of tokens."""
|
126 |
+
return "".join([ID_TO_AA[token] if token in ID_TO_AA else "<unk>" for token in sequence])
|
127 |
+
|
128 |
+
def clean_sequence(sequence):
|
129 |
+
"""Remove gaps and convert all residues to upper case."""
|
130 |
+
return sequence.replace("-", "").upper()
|
131 |
+
|
132 |
+
def tokenizer(sequence_list, concatenate=True):
|
133 |
+
"""Tokenize a collection of sequences. If the sequences are aligned, the gaps will be removed
|
134 |
+
and the insertions (lower case) will be promoted to upper case."""
|
135 |
+
# clean and encode all sequences
|
136 |
+
sequence_list = [encode_sequence(clean_sequence(sequence)) for sequence in sequence_list]
|
137 |
+
if concatenate:
|
138 |
+
# concatenate all sequences
|
139 |
+
sequences = np.concatenate(sequence_list)
|
140 |
+
# convert to tensor and add batch dimension
|
141 |
+
return torch.asarray(sequences, dtype=torch.int8)[None,:]
|
142 |
+
else:
|
143 |
+
return [torch.asarray(sequence, dtype=torch.int8) for sequence in sequence_list]
|
144 |
+
|
145 |
+
def reorder_masked_sequence(mask_seq, return_ids=False):
|
146 |
+
"""
|
147 |
+
Reorder a masked sequence to fill the masked positions with the tokens
|
148 |
+
that should be there but are positioned after the <eos> token.
|
149 |
+
"""
|
150 |
+
mask_seq = mask_seq.split("<cls>")[0]
|
151 |
+
try:
|
152 |
+
# Split the sequence and masks
|
153 |
+
seq, masks = mask_seq.split("<eos>")
|
154 |
+
except:
|
155 |
+
return mask_seq
|
156 |
+
full_seq = ""
|
157 |
+
ids_mask = []
|
158 |
+
# Iterate over each mask tag
|
159 |
+
for mm in ["<mask-1>", "<mask-2>", "<mask-3>", "<mask-4>", "<mask-5>","<mask-?>"]:
|
160 |
+
try:
|
161 |
+
# Split the sequence in before and after the mask tag
|
162 |
+
seq1, seq2 = seq.split(mm)
|
163 |
+
if mm=="<mask-1>":
|
164 |
+
# If the mask is the first one, add the sequence before the mask and update the masks
|
165 |
+
masks = masks.split("<mask-1>")[1]
|
166 |
+
full_seq += seq1
|
167 |
+
else:
|
168 |
+
# If the mask is not the first one, insert the mask between the two sequence parts
|
169 |
+
masks1, masks2 = masks.split(mm)
|
170 |
+
ids_mask += [(len(full_seq), len(full_seq)+len(masks1))]
|
171 |
+
full_seq += masks1 + seq1
|
172 |
+
# Update the masks
|
173 |
+
masks = masks2
|
174 |
+
# Update the sequence with the part after the mask
|
175 |
+
seq = seq2
|
176 |
+
except:
|
177 |
+
# If the mask is not found, add the remaining sequence
|
178 |
+
ids_mask += [(len(full_seq), len(full_seq)+len(masks))]
|
179 |
+
full_seq += masks + seq
|
180 |
+
break
|
181 |
+
if return_ids:
|
182 |
+
return full_seq, ids_mask
|
183 |
+
return full_seq
|
184 |
+
|
185 |
+
def load_sequences_from_msa_file(file_path):
|
186 |
+
"""Load a collection of sequences from an a3m file."""
|
187 |
+
with open(file_path, "r") as f:
|
188 |
+
sequences = [str(record.seq) for record in SeqIO.parse(f, "fasta")]
|
189 |
+
return sequences
|
190 |
+
|
191 |
+
def prepare_dataset_for_fim_generation(tokens, pos_ids):
|
192 |
+
"""
|
193 |
+
Function to transform the tokenized training dataset into a format that can be used for FIM generation.
|
194 |
+
Splits the input tokens and pos_ids into the FIM part (of the last sequence) and the context part (all
|
195 |
+
the previous sequences and the masked part of the last sequence).
|
196 |
+
Also returns a dictionary with the positions of the mask tokens in the FIM part.
|
197 |
+
"""
|
198 |
+
def find_mask_positions(tokens_fim):
|
199 |
+
"""
|
200 |
+
Function to find the positions of the mask tokens in the FIM part of the last sequence.
|
201 |
+
"""
|
202 |
+
bool_mask = None
|
203 |
+
inds_masks = []
|
204 |
+
for ind in MASK_TO_ID.values():
|
205 |
+
tmp_bool = tokens_fim[0].cpu().numpy() == ind
|
206 |
+
bool_mask = tmp_bool if bool_mask is None else bool_mask | tmp_bool
|
207 |
+
inds_masks += [ind]
|
208 |
+
return bool_mask, inds_masks
|
209 |
+
# find where the FIM part of the last sequence starts
|
210 |
+
start_last_fim = np.where(tokens[0].cpu().numpy() == AA_TO_ID["<eos>"])[0][-1]
|
211 |
+
start_next_seqs = np.where(tokens[0,start_last_fim+1:].cpu().numpy() == AA_TO_ID["<cls>"])[0]
|
212 |
+
end_last_fim = start_last_fim+ 1 +start_next_seqs[0] if len(start_next_seqs) > 0 else tokens.shape[1]
|
213 |
+
# split tokens and pos_ids into FIM part and context part
|
214 |
+
tokens_to_fim = tokens[:,:start_last_fim+1]
|
215 |
+
pos_ids_to_fim = pos_ids[:,:start_last_fim+1]
|
216 |
+
tokens_fim = tokens[:,start_last_fim+1:end_last_fim]
|
217 |
+
pos_ids_fim = pos_ids[:,start_last_fim+1:end_last_fim]
|
218 |
+
# find positions of mask tokens
|
219 |
+
bool_mask, inds_masks = find_mask_positions(tokens_fim)
|
220 |
+
masked_positions = pos_ids_fim[0,bool_mask]
|
221 |
+
mask_dict = {ind: int(pos) for ind, pos in zip(inds_masks, masked_positions)}
|
222 |
+
return tokens_to_fim, pos_ids_to_fim, tokens_fim, pos_ids_fim, mask_dict
|
223 |
+
|
224 |
+
# Metrics
|
225 |
+
def find_fim_indices(is_cls_tokens, is_eos_tokens):
|
226 |
+
"""Function to find the indices of the FIM tokens in the sequences.
|
227 |
+
"""
|
228 |
+
# add a cls token at the beginning
|
229 |
+
is_cls_tokens = torch.cat([torch.ones_like(is_cls_tokens[:, :1]), is_cls_tokens], dim=1)
|
230 |
+
is_eos_tokens = torch.cat([torch.zeros_like(is_eos_tokens[:, :1]), is_eos_tokens], dim=1)
|
231 |
+
# both eos and cls tokens
|
232 |
+
bol = is_cls_tokens | is_eos_tokens
|
233 |
+
tmp = torch.zeros_like(is_cls_tokens, dtype=torch.int)
|
234 |
+
tmp[torch.nonzero(is_cls_tokens, as_tuple=True)] = 1
|
235 |
+
tmp[torch.nonzero(is_eos_tokens, as_tuple=True)] = -1
|
236 |
+
bol1 = torch.clone(bol)
|
237 |
+
for batch_ind in range(tmp.size(0)):
|
238 |
+
tmp1 = tmp[batch_ind,bol[batch_ind]]
|
239 |
+
# find all positions where a 1 if preceeded by a -1
|
240 |
+
tmp1 = tmp1[:-1]*tmp1[1:]
|
241 |
+
# add the first element to make the sequence start with a 1
|
242 |
+
tmp1 = torch.cat([torch.ones_like(tmp1[:1]).to(tmp1.device), tmp1])
|
243 |
+
new_bol = tmp1<0
|
244 |
+
# bool array True only in the positions where a 1 is preceded by a -1
|
245 |
+
bol1[batch_ind,bol[batch_ind]] = False if new_bol.size(0) == 0 else new_bol
|
246 |
+
cumulative_sum = torch.cumsum(bol1, dim=1)
|
247 |
+
# Use modulo operation to get the desired tensor
|
248 |
+
bol2 = cumulative_sum % 2 == 1
|
249 |
+
bol2[is_eos_tokens]= False
|
250 |
+
return bol2[:,1:]
|
251 |
+
|
252 |
+
def compute_metrics(eval_pred):
|
253 |
+
predictions, labels = eval_pred
|
254 |
+
predictions = torch.tensor(predictions).permute(0, 2, 1)
|
255 |
+
labels = torch.tensor(labels)
|
256 |
+
# shift labels to align them with predictions and remove last prediction to match the length
|
257 |
+
predictions = predictions[:, :, :-1].contiguous()
|
258 |
+
labels = labels[:, 1:].contiguous()
|
259 |
+
# compute unreduced elementwise loss
|
260 |
+
unreduced_loss = torch.nn.functional.cross_entropy(predictions, labels, reduction="none")
|
261 |
+
# compute reconstruction accuracy
|
262 |
+
reconstruction = (predictions.argmax(1) == labels)
|
263 |
+
|
264 |
+
# start and end tokens
|
265 |
+
is_cls_tokens = (labels == AA_TO_ID["<cls>"])
|
266 |
+
is_eos_tokens = (labels == AA_TO_ID["<eos>"])
|
267 |
+
# fill in the middle tokens
|
268 |
+
if False:
|
269 |
+
fim_tokens = torch.zeros(is_cls_tokens.size(0), is_cls_tokens.size(1), dtype=torch.bool)
|
270 |
+
in_mask_vector = torch.zeros(is_cls_tokens.size(0), dtype=torch.bool)
|
271 |
+
for j in range(is_cls_tokens.size(1)):
|
272 |
+
in_mask_vector = in_mask_vector & ~is_cls_tokens[:, j]
|
273 |
+
fim_tokens[:, j] = in_mask_vector
|
274 |
+
in_mask_vector = in_mask_vector | is_eos_tokens[:, j]
|
275 |
+
fim_tokens = find_fim_indices(is_cls_tokens, is_eos_tokens)
|
276 |
+
|
277 |
+
number_sequences = torch.cumsum(torch.cat([torch.zeros(is_cls_tokens.size(0),1, dtype=torch.int32), is_cls_tokens[:,:-1]],1), -1)
|
278 |
+
# first, second and last sequence tokens
|
279 |
+
first_sequence_tokens = ((~fim_tokens & (labels < 33)) | fim_tokens) & (number_sequences == 0)
|
280 |
+
second_sequence_tokens = ((~fim_tokens & (labels < 33)) | fim_tokens) & (number_sequences == 1)
|
281 |
+
last_sequence_tokens = ((~fim_tokens & (labels < 33)) | fim_tokens) & (number_sequences == (number_sequences.max(1).values[:, None] - 1))
|
282 |
+
# end of mask tokens
|
283 |
+
end_of_masks = (fim_tokens & (labels > 33)) | is_cls_tokens | is_eos_tokens
|
284 |
+
|
285 |
+
return {
|
286 |
+
"loss/all": torch.mean(unreduced_loss).item(),
|
287 |
+
"loss/end_span": torch.mean(unreduced_loss[end_of_masks]).item(),
|
288 |
+
"perplexity/seq": torch.mean(torch.exp(torch.mean(unreduced_loss, dim=1))).item(),
|
289 |
+
"perplexity/end_span": torch.exp(torch.mean(unreduced_loss[end_of_masks])).item(),
|
290 |
+
"perplexity/batch": torch.exp(torch.mean(unreduced_loss)).item(),
|
291 |
+
"perplexity/first_seq": torch.exp(torch.mean(unreduced_loss[first_sequence_tokens])).item(),
|
292 |
+
"perplexity/second_seq": torch.exp(torch.mean(unreduced_loss[second_sequence_tokens])).item(),
|
293 |
+
"perplexity/last_seq": torch.exp(torch.mean(unreduced_loss[last_sequence_tokens])).item(),
|
294 |
+
"perplexity/fim": torch.exp(torch.mean(unreduced_loss[fim_tokens])).item(),
|
295 |
+
"reconstruction/all": torch.mean(reconstruction.float()).item(),
|
296 |
+
"reconstruction/end_span": torch.mean(reconstruction[end_of_masks].float()).item(),
|
297 |
+
"reconstruction/first_seq": torch.mean(reconstruction[first_sequence_tokens].float()).item(),
|
298 |
+
"reconstruction/second_seq": torch.mean(reconstruction[second_sequence_tokens].float()).item(),
|
299 |
+
"reconstruction/last_seq": torch.mean(reconstruction[last_sequence_tokens].float()).item(),
|
300 |
+
"reconstruction/fim": torch.mean(reconstruction[fim_tokens].float()).item(),
|
301 |
+
}
|
302 |
+
|
303 |
+
def compute_metrics_with_std(eval_pred):
|
304 |
+
predictions, labels = eval_pred
|
305 |
+
predictions = torch.tensor(predictions).permute(0, 2, 1)
|
306 |
+
labels = torch.tensor(labels)
|
307 |
+
# shift labels to align them with predictions and remove last prediction to match the length
|
308 |
+
predictions = predictions[:, :, :-1].contiguous()
|
309 |
+
labels = labels[:, 1:].contiguous()
|
310 |
+
# compute unreduced elementwise loss
|
311 |
+
unreduced_loss = torch.nn.functional.cross_entropy(predictions, labels, reduction="none")
|
312 |
+
# compute reconstruction accuracy
|
313 |
+
reconstruction = (predictions.argmax(1) == labels)
|
314 |
+
|
315 |
+
# start and end tokens
|
316 |
+
is_cls_tokens = (labels == AA_TO_ID["<cls>"])
|
317 |
+
is_eos_tokens = (labels == AA_TO_ID["<eos>"])
|
318 |
+
# fill in the middle tokens
|
319 |
+
if False:
|
320 |
+
fim_tokens = torch.zeros(is_cls_tokens.size(0), is_cls_tokens.size(1), dtype=torch.bool)
|
321 |
+
in_mask_vector = torch.zeros(is_cls_tokens.size(0), dtype=torch.bool)
|
322 |
+
for j in range(is_cls_tokens.size(1)):
|
323 |
+
in_mask_vector = in_mask_vector & ~is_cls_tokens[:, j]
|
324 |
+
fim_tokens[:, j] = in_mask_vector
|
325 |
+
in_mask_vector = in_mask_vector | is_eos_tokens[:, j]
|
326 |
+
fim_tokens = find_fim_indices(is_cls_tokens, is_eos_tokens)
|
327 |
+
|
328 |
+
number_sequences = torch.cumsum(torch.cat([torch.zeros(is_cls_tokens.size(0),1, dtype=torch.int32), is_cls_tokens[:,:-1]],1), -1)
|
329 |
+
# first, second and last sequence tokens
|
330 |
+
first_sequence_tokens = ((~fim_tokens & (labels < 33)) | fim_tokens) & (number_sequences == 0)
|
331 |
+
second_sequence_tokens = ((~fim_tokens & (labels < 33)) | fim_tokens) & (number_sequences == 1)
|
332 |
+
last_sequence_tokens = ((~fim_tokens & (labels < 33)) | fim_tokens) & (number_sequences == (number_sequences.max(1).values[:, None] - 1))
|
333 |
+
# end of mask tokens
|
334 |
+
end_of_masks = (fim_tokens & (labels > 33)) | is_cls_tokens | is_eos_tokens
|
335 |
+
|
336 |
+
def perplexities_per_seq_for_subset(unreduced_loss, subset):
|
337 |
+
return torch.exp(torch.nanmean(torch.where(subset, unreduced_loss, torch.tensor(float('nan'))), dim=1))
|
338 |
+
|
339 |
+
return{
|
340 |
+
# Loss
|
341 |
+
"loss/all": torch.mean(unreduced_loss).item(),
|
342 |
+
"loss/std": torch.std(unreduced_loss).item(),
|
343 |
+
"loss/end_span": torch.mean(unreduced_loss[end_of_masks]).item(),
|
344 |
+
"loss/end_span_std": torch.std(unreduced_loss[end_of_masks]).item(),
|
345 |
+
|
346 |
+
# Perplexity of all tokens
|
347 |
+
"perplexity/batch": torch.exp(torch.mean(unreduced_loss)).item(),
|
348 |
+
"perplexity/batch_std": torch.exp(torch.std(unreduced_loss)).item(), # Fix
|
349 |
+
|
350 |
+
# Perplexity per sequence
|
351 |
+
"perplexity/seq": torch.mean(torch.exp(torch.mean(unreduced_loss, dim=1))).item(),
|
352 |
+
"perplexity/seq_std": torch.std(torch.exp(torch.mean(unreduced_loss, dim=1))).item(),
|
353 |
+
"perplexity/end_span": torch.exp(torch.mean(unreduced_loss[end_of_masks])).item(),
|
354 |
+
"perplexity/end_span_std": torch.std(torch.exp(unreduced_loss[end_of_masks])).item(),
|
355 |
+
|
356 |
+
"perplexity/first_seq": torch.mean(perplexities_per_seq_for_subset(unreduced_loss, first_sequence_tokens)).item(),
|
357 |
+
"perplexity/first_seq_std": torch.std(perplexities_per_seq_for_subset(unreduced_loss, first_sequence_tokens)).item(),
|
358 |
+
"perplexity/second_seq": torch.mean(perplexities_per_seq_for_subset(unreduced_loss, second_sequence_tokens)).item(),
|
359 |
+
"perplexity/second_seq_std": torch.std(perplexities_per_seq_for_subset(unreduced_loss, second_sequence_tokens)).item(),
|
360 |
+
"perplexity/last_seq": torch.mean(perplexities_per_seq_for_subset(unreduced_loss, last_sequence_tokens)).item(),
|
361 |
+
"perplexity/last_seq_std": torch.std(perplexities_per_seq_for_subset(unreduced_loss, last_sequence_tokens)).item(),
|
362 |
+
"perplexity/fim": torch.mean(perplexities_per_seq_for_subset(unreduced_loss, fim_tokens)).item(),
|
363 |
+
"perplexity/fim_std": torch.std(perplexities_per_seq_for_subset(unreduced_loss, fim_tokens)).item(),
|
364 |
+
"reconstruction/all": torch.mean(reconstruction.float()).item(),
|
365 |
+
"reconstruction/std": torch.std(reconstruction.float()).item(),
|
366 |
+
"reconstruction/end_span": torch.mean(reconstruction[end_of_masks].float()).item(),
|
367 |
+
"reconstruction/end_span_std": torch.std(reconstruction[end_of_masks].float()).item(),
|
368 |
+
"reconstruction/first_seq": torch.mean(reconstruction[first_sequence_tokens].float()).item(),
|
369 |
+
"reconstruction/first_seq_std": torch.std(reconstruction[first_sequence_tokens].float()).item(),
|
370 |
+
"reconstruction/second_seq": torch.mean(reconstruction[second_sequence_tokens].float()).item(),
|
371 |
+
"reconstruction/second_seq_std": torch.std(reconstruction[second_sequence_tokens].float()).item(),
|
372 |
+
"reconstruction/last_seq": torch.mean(reconstruction[last_sequence_tokens].float()).item(),
|
373 |
+
"reconstruction/last_seq_std": torch.std(reconstruction[last_sequence_tokens].float()).item(),
|
374 |
+
"reconstruction/fim": torch.mean(reconstruction[fim_tokens].float()).item(),
|
375 |
+
"reconstruction/fim_std": torch.std(reconstruction[fim_tokens].float()).item(),
|
376 |
+
}
|
377 |
+
|
378 |
+
# Others
|
379 |
+
def set_optimizer_and_scheduler(config, ntrain, parameters):
|
380 |
+
|
381 |
+
# Set optimizer
|
382 |
+
optimizer = AdamW(
|
383 |
+
parameters,
|
384 |
+
lr=config["learning_rate"],
|
385 |
+
betas=(config["beta1"], config["beta2"]),
|
386 |
+
weight_decay=config["weight_decay"],
|
387 |
+
)
|
388 |
+
|
389 |
+
eff_batch_size = config["batch_size"] * config["gradient_accumulation_steps"] * torch.cuda.device_count()
|
390 |
+
|
391 |
+
# Set scheduler
|
392 |
+
if config["scheduler"] == "cosine":
|
393 |
+
print_zero_rank("Using cosine scheduler")
|
394 |
+
scheduler = get_cosine_schedule_with_warmup(
|
395 |
+
optimizer,
|
396 |
+
num_warmup_steps=config["warmup_steps"],
|
397 |
+
num_training_steps=config["num_epochs"] * ntrain // eff_batch_size,
|
398 |
+
)
|
399 |
+
if config["scheduler"] == "cosine-restarts":
|
400 |
+
scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
|
401 |
+
optimizer,
|
402 |
+
num_warmup_steps=config["warmup_steps"],
|
403 |
+
num_training_steps=config["num_epochs"] * ntrain // eff_batch_size,
|
404 |
+
num_cycles=config["num_cycles"],
|
405 |
+
)
|
406 |
+
elif config["scheduler"] == "constant":
|
407 |
+
print_zero_rank("Using constant scheduler with warmup")
|
408 |
+
scheduler = get_constant_schedule_with_warmup(
|
409 |
+
optimizer, num_warmup_steps=config["warmup_steps"]
|
410 |
+
)
|
411 |
+
else:
|
412 |
+
raise ValueError("Scheduler must be either cosine or constant")
|
413 |
+
|
414 |
+
# Finetuning and no optimizer/scheduler reset
|
415 |
+
if config.finetune_model_path and not config.restart_optimizer_and_scheduler:
|
416 |
+
optimizer.load_state_dict(torch.load(config.finetune_model_path + "/optimizer.pt"))
|
417 |
+
for param_group in optimizer.param_groups:
|
418 |
+
param_group['initial_lr'] = config['learning_rate']
|
419 |
+
param_group['lr'] = config['learning_rate']
|
420 |
+
|
421 |
+
scheduler.load_state_dict(torch.load(config.finetune_model_path + "/scheduler.pt"))
|
422 |
+
scheduler.base_lrs = [config['learning_rate']]
|
423 |
+
scheduler._last_lr = [config['learning_rate']]
|
424 |
+
|
425 |
+
return optimizer, scheduler
|
426 |
+
|
427 |
+
def parse_override_args(override_args):
|
428 |
+
overrides = {}
|
429 |
+
for arg in override_args:
|
430 |
+
key, value = arg.split("=")
|
431 |
+
keys = key.split(".")
|
432 |
+
sub_dict = overrides
|
433 |
+
for sub_key in keys[:-1]:
|
434 |
+
if sub_key not in sub_dict:
|
435 |
+
sub_dict[sub_key] = {}
|
436 |
+
sub_dict = sub_dict[sub_key]
|
437 |
+
# Convert value to appropriate type
|
438 |
+
if value == 'True':
|
439 |
+
value = True
|
440 |
+
elif value == 'False':
|
441 |
+
value = False
|
442 |
+
elif value == 'None':
|
443 |
+
value = None
|
444 |
+
else:
|
445 |
+
try:
|
446 |
+
value = int(value)
|
447 |
+
except ValueError:
|
448 |
+
try:
|
449 |
+
value = float(value)
|
450 |
+
except ValueError:
|
451 |
+
pass
|
452 |
+
sub_dict[keys[-1]] = value
|
453 |
+
return overrides
|
454 |
+
|
455 |
+
def load_model(
|
456 |
+
model_path,
|
457 |
+
device,
|
458 |
+
model_class,
|
459 |
+
dtype=torch.bfloat16,
|
460 |
+
**kwargs
|
461 |
+
):
|
462 |
+
model = model_class.from_pretrained(
|
463 |
+
model_path, device=device, dtype=dtype, **kwargs
|
464 |
+
)
|
465 |
+
return model
|
466 |
+
|
467 |
+
# https://github.com/state-spaces/mamba/blob/main/mamba_ssm/utils/hf.py
|
468 |
+
def load_config_hf(model_name):
|
469 |
+
resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
|
470 |
+
return json.load(open(resolved_archive_file))
|
471 |
+
|
472 |
+
# https://github.com/state-spaces/mamba/blob/main/mamba_ssm/utils/hf.py
|
473 |
+
def load_state_dict_hf(model_name, device=None, dtype=None):
|
474 |
+
# If not fp32, then we don't want to load directly to the GPU
|
475 |
+
mapped_device = "cpu" if dtype not in [torch.float32, None] else device
|
476 |
+
resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
|
477 |
+
state_dict = torch.load(resolved_archive_file, map_location=mapped_device)
|
478 |
+
# Convert dtype before moving to GPU to save memory
|
479 |
+
if dtype is not None:
|
480 |
+
state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
|
481 |
+
state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
|
482 |
+
return state_dict
|
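The override parser above turns dotted `key=value` strings into a nested config dict, coercing booleans, ints and floats along the way. A minimal sketch of how it might be exercised (the override strings and the CLI invocation are illustrative, not taken from the repository):

```python
from protxlstm.utils import parse_override_args

# hypothetical command-line overrides, e.g. `python protxlstm/train.py model.num_heads=8 ...`
overrides = parse_override_args([
    "model.num_heads=8",          # nested key, value coerced to int
    "train.learning_rate=1e-4",   # coerced to float
    "train.resume=False",         # coerced to bool
])
print(overrides)
# {'model': {'num_heads': 8}, 'train': {'learning_rate': 0.0001, 'resume': False}}
```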
protxlstm/xlstm/__init__.py
ADDED
@@ -0,0 +1,6 @@
1 |
+
from .blocks.mlstm.block import mLSTMBlock, mLSTMBlockConfig
|
2 |
+
from .blocks.mlstm.layer import mLSTMLayer, mLSTMLayerConfig
|
3 |
+
from .components.feedforward import FeedForwardConfig, GatedFeedForward
|
4 |
+
from .components.rotary_position import compute_freqs_cis, apply_rotary_emb
|
5 |
+
from .xlstm_block_stack import xLSTMBlockStack, xLSTMBlockStackConfig
|
6 |
+
from .xlstm_lm_model import xLSTMLMModel, xLSTMLMModelConfig
|
protxlstm/xlstm/blocks/__init__.py
ADDED
File without changes
|
protxlstm/xlstm/blocks/mlstm/__init__.py
ADDED
@@ -0,0 +1 @@
1 |
+
|
protxlstm/xlstm/blocks/mlstm/backends.py
ADDED
@@ -0,0 +1,314 @@
1 |
+
# Copyright (c) NXAI GmbH and its affiliates 2024
|
2 |
+
# Maximilian Beck
|
3 |
+
|
4 |
+
# Modified by Pieter-Jan Hoedt, Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
|
5 |
+
# - Fix numerical issues between parallel and stepwise backends
|
6 |
+
# - Make chunkwise implementation compatible with stepwise backend and variable sequence lengths
|
7 |
+
|
8 |
+
|
9 |
+
import math
|
10 |
+
from typing import Union, Tuple, Optional
|
11 |
+
import torch
|
12 |
+
|
13 |
+
|
14 |
+
def parallel_stabilized_simple(
|
15 |
+
queries: torch.Tensor,
|
16 |
+
keys: torch.Tensor,
|
17 |
+
values: torch.Tensor,
|
18 |
+
igate_preact: torch.Tensor,
|
19 |
+
fgate_preact: torch.Tensor,
|
20 |
+
lower_triangular_matrix: torch.Tensor = None,
|
21 |
+
stabilize_rowwise: bool = True,
|
22 |
+
eps: float = 1e-6,
|
23 |
+
**kwargs,
|
24 |
+
) -> torch.Tensor:
|
25 |
+
"""This is the mLSTM cell in parallel form.
|
26 |
+
This version is stabilized. We control the range of exp() arguments by
|
27 |
+
ensuring that they are always smaller than 0.0 by subtracting the maximum.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
queries (torch.Tensor): (B, NH, S, DH)
|
31 |
+
keys (torch.Tensor): (B, NH, S, DH)
|
32 |
+
values (torch.Tensor): (B, NH, S, DH)
|
33 |
+
igate_preact (torch.Tensor): (B, NH, S, 1)
|
34 |
+
fgate_preact (torch.Tensor): (B, NH, S, 1)
|
35 |
+
lower_triangular_matrix (torch.Tensor, optional): (S,S). Defaults to None.
|
36 |
+
stabilize_rowwise (bool, optional): Whether to stabilize the combination matrix C rowwise (take maximum per row).
|
37 |
+
Alternative: Subtract the maximum over all rows. Defaults to True.
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
torch.Tensor: (B, NH, S, DH), h_tilde_state
|
41 |
+
"""
|
42 |
+
|
43 |
+
B, NH, S, DH = queries.shape
|
44 |
+
_dtype, _device = queries.dtype, queries.device
|
45 |
+
|
46 |
+
# forget gate matrix
|
47 |
+
log_fgates = torch.nn.functional.logsigmoid(fgate_preact) # (B, NH, S, 1)
|
48 |
+
if lower_triangular_matrix is None or S < lower_triangular_matrix.size(-1):
|
49 |
+
ltr = torch.tril(torch.ones((S, S), dtype=torch.bool, device=_device))
|
50 |
+
else:
|
51 |
+
ltr = lower_triangular_matrix
|
52 |
+
assert (
|
53 |
+
ltr.dtype == torch.bool
|
54 |
+
), f"lower_triangular_matrix must be of dtype bool, got {ltr.dtype}"
|
55 |
+
|
56 |
+
log_f_mat = torch.tril(log_fgates.repeat(1, 1, 1, S), diagonal=-1)
|
57 |
+
log_prod_f_mat = torch.cumsum(log_f_mat, dim=-2)
|
58 |
+
# Causal masking & selection of the correct submatrix, such that forgetgate at timestep t is not applied
|
59 |
+
# to the input at timestep t
|
60 |
+
log_fg_matrix = torch.where(ltr, log_prod_f_mat, -float("inf")) # (B, NH, S, S)
|
61 |
+
|
62 |
+
# gate decay matrix D (combination of forget gate and input gate)
|
63 |
+
log_D_matrix = log_fg_matrix + igate_preact.transpose(-2, -1) # (B, NH, S, S)
|
64 |
+
# D matrix stabilization
|
65 |
+
if stabilize_rowwise:
|
66 |
+
max_log_D, _ = torch.max(log_D_matrix, dim=-1, keepdim=True) # (B, NH, S, 1)
|
67 |
+
else:
|
68 |
+
max_log_D = torch.max(log_D_matrix.view(B, NH, -1), dim=-1, keepdim=True)[
|
69 |
+
0
|
70 |
+
].unsqueeze(-1)
|
71 |
+
# (B, NH, 1, 1)
|
72 |
+
log_D_matrix_stabilized = log_D_matrix - max_log_D # (B, NH, S, S)
|
73 |
+
D_matrix = torch.exp(log_D_matrix_stabilized) # (B, NH, S, S)
|
74 |
+
|
75 |
+
keys_scaled = keys / math.sqrt(DH)
|
76 |
+
|
77 |
+
# combination matrix C
|
78 |
+
qk_matrix = queries @ keys_scaled.transpose(-2, -1) # (B, NH, S, S)
|
79 |
+
C_matrix = qk_matrix * D_matrix # (B, NH, S, S)
|
80 |
+
normalizer = torch.maximum(
|
81 |
+
C_matrix.sum(dim=-1, keepdim=True).abs(), torch.exp(-max_log_D)
|
82 |
+
) # (B, NH, S, 1)
|
83 |
+
# (B, NH, S, S)
|
84 |
+
C_matrix_normalized = C_matrix / (normalizer + eps)
|
85 |
+
|
86 |
+
# retrieved values
|
87 |
+
h_tilde_state = C_matrix_normalized @ values # (B, NH, S, DH)
|
88 |
+
|
89 |
+
return h_tilde_state
|
90 |
+
|
91 |
+
|
92 |
+
def chunkwise_simple(
|
93 |
+
queries: torch.Tensor,
|
94 |
+
keys: torch.Tensor, # B, NH, S, DH
|
95 |
+
values: torch.Tensor, # B, NH, S, DH
|
96 |
+
igate_preact: torch.Tensor, # B, NH, S
|
97 |
+
fgate_preact: torch.Tensor, # B, NH, S
|
98 |
+
initial_C: Optional[torch.Tensor] = None, # B, NH, DH, DH
|
99 |
+
initial_n: Optional[torch.Tensor] = None, # B, NH, DH, 1
|
100 |
+
initial_m: Optional[torch.Tensor] = None, # B, NH, 1, 1
|
101 |
+
chunk_size: int = 64, # optimize this
|
102 |
+
return_last_state: bool = False,
|
103 |
+
eps: float = 1e-6,
|
104 |
+
**kwargs,
|
105 |
+
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
|
106 |
+
B, NH, S, DH = queries.shape
|
107 |
+
NS, CS = S // chunk_size, chunk_size
|
108 |
+
_dtype, _device = queries.dtype, queries.device
|
109 |
+
|
110 |
+
# form chunks
|
111 |
+
q = queries.view(B, NH, NS, CS, DH)
|
112 |
+
k = keys.view(B, NH, NS, CS, DH) / math.sqrt(DH)
|
113 |
+
v = values.view(B, NH, NS, CS, DH)
|
114 |
+
|
115 |
+
# forget gates
|
116 |
+
log_fgates = torch.nn.functional.logsigmoid(fgate_preact).view(B, NH, NS, CS)
|
117 |
+
log_fgates_acc = log_fgates.cumsum(dim=3)
|
118 |
+
igate_preact = igate_preact.view(B, NH, NS, CS)
|
119 |
+
|
120 |
+
log_fgates_rep = log_fgates[:, :, :, :, None].repeat(1, 1, 1, 1, CS)
|
121 |
+
log_fg_matrix = torch.tril(log_fgates_rep, diagonal=-1)
|
122 |
+
log_prod_fg_matrix = torch.cumsum(log_fg_matrix, dim=3)
|
123 |
+
|
124 |
+
loggates = (igate_preact + log_prod_fg_matrix[:, :, :, -1]).unsqueeze(-1)
|
125 |
+
m_loc, _ = torch.max(loggates, dim=3, keepdim=True)
|
126 |
+
loggates = loggates - m_loc
|
127 |
+
|
128 |
+
kv = k.transpose(-1, -2) @ (v * (loggates).exp())
|
129 |
+
ksum = (k * (loggates).exp()).sum(dim=-2)
|
130 |
+
C = torch.zeros((B, NH, NS + 1, DH, DH), device=kv.device, dtype=kv.dtype)
|
131 |
+
n = torch.zeros((B, NH, NS + 1, DH, 1), device=kv.device, dtype=kv.dtype)
|
132 |
+
if initial_C is not None:
|
133 |
+
C[:, :, 0] = initial_C
|
134 |
+
if initial_n is not None:
|
135 |
+
n[:, :, 0] = initial_n
|
136 |
+
|
137 |
+
m = torch.zeros((B, NH, NS + 1, 1, 1), device=kv.device, dtype=kv.dtype)
|
138 |
+
if initial_m is not None:
|
139 |
+
m[:, :, 0] = initial_m
|
140 |
+
|
141 |
+
for i in range(1, NS + 1):
|
142 |
+
m[:, :, i] = torch.maximum(
|
143 |
+
log_fgates_acc[:, :, i - 1, -1, None, None] + m[:, :, i - 1],
|
144 |
+
m_loc[:, :, i - 1],
|
145 |
+
)
|
146 |
+
C[:, :, i] = (
|
147 |
+
C[:, :, i - 1].clone()
|
148 |
+
* (
|
149 |
+
log_fgates_acc[:, :, i - 1, -1, None, None]
|
150 |
+
+ m[:, :, i - 1]
|
151 |
+
- m[:, :, i]
|
152 |
+
).exp()
|
153 |
+
+ kv[:, :, i - 1] * (m_loc[:, :, i - 1] - m[:, :, i]).exp()
|
154 |
+
)
|
155 |
+
n[:, :, i] = (
|
156 |
+
n[:, :, i - 1].clone()
|
157 |
+
* (
|
158 |
+
log_fgates_acc[:, :, i - 1, None, -1:]
|
159 |
+
+ m[:, :, i - 1]
|
160 |
+
- m[:, :, i]
|
161 |
+
).exp()
|
162 |
+
+ ksum[:, :, i - 1, :, None] * (m_loc[:, :, i - 1] - m[:, :, i]).exp()
|
163 |
+
)
|
164 |
+
|
165 |
+
log_fg_matrix = log_prod_fg_matrix - torch.triu(
|
166 |
+
torch.full([1, 1, 1, CS, CS], float("inf")).to(q), diagonal=1
|
167 |
+
)
|
168 |
+
|
169 |
+
# gate decay matrix D (combination of forget gate and input gate)
|
170 |
+
log_D_matrix = log_fg_matrix + igate_preact[:, :, :, :, None].transpose(
|
171 |
+
-2, -1
|
172 |
+
) # (B, NH, NS, CS, CS)
|
173 |
+
D_max, _ = torch.max(log_D_matrix, dim=-1, keepdim=True)
|
174 |
+
|
175 |
+
stab = torch.maximum(D_max, m[:, :, :-1, :] + log_fgates_acc[:, :, :, :, None])
|
176 |
+
inter_C = (
|
177 |
+
q * (m[:, :, :-1, :] + log_fgates_acc[:, :, :, :, None] - stab).exp()
|
178 |
+
) @ C[:, :, :-1]
|
179 |
+
inter_n = (
|
180 |
+
q * (m[:, :, :-1, :] + log_fgates_acc[:, :, :, :, None] - stab).exp()
|
181 |
+
) @ n[:, :, :-1, :]
|
182 |
+
|
183 |
+
# D matrix stabilization
|
184 |
+
log_D_matrix_stabilized = log_D_matrix - stab # (B, NH, NS, CS, CS)
|
185 |
+
D_matrix = torch.exp(log_D_matrix_stabilized) # (B, NH, NS, CS, CS)
|
186 |
+
|
187 |
+
# combination matrix C
|
188 |
+
qk_matrix = q @ k.transpose(-2, -1) # (B, NH, NS, CS, CS)
|
189 |
+
E_matrix = qk_matrix * D_matrix # (B, NH, NS, CS, CS)
|
190 |
+
|
191 |
+
normalizer = torch.maximum(
|
192 |
+
(E_matrix.sum(dim=-1, keepdim=True) + inter_n).abs(),
|
193 |
+
torch.exp(-stab),
|
194 |
+
) # (B, NH, NS, CS, 1)
|
195 |
+
|
196 |
+
E_matrix_normalized = E_matrix / (normalizer + eps)
|
197 |
+
|
198 |
+
# retrieved values
|
199 |
+
intra = E_matrix_normalized @ v # (B, NH, S, DH)
|
200 |
+
inter = inter_C / (normalizer + eps)
|
201 |
+
|
202 |
+
if return_last_state:
|
203 |
+
return (intra + inter).view((B, NH, S, DH)), (C[:, :, -1], n[:, :, -1], m[:, :, -1])
|
204 |
+
else:
|
205 |
+
return (intra + inter).view((B, NH, S, DH))
|
206 |
+
|
207 |
+
|
208 |
+
# chunkwise backend adapted to handle inputs which are not cleanly divisible by chunk_size
|
209 |
+
def chunkwise_variable(
|
210 |
+
queries: torch.Tensor,
|
211 |
+
keys: torch.Tensor,
|
212 |
+
values: torch.Tensor,
|
213 |
+
igate_preact: torch.Tensor,
|
214 |
+
fgate_preact: torch.Tensor,
|
215 |
+
initial_C: Optional[torch.Tensor] = None,
|
216 |
+
initial_n: Optional[torch.Tensor] = None,
|
217 |
+
initial_m: Optional[torch.Tensor] = None,
|
218 |
+
chunk_size: int = 64,
|
219 |
+
return_last_state: bool = False,
|
220 |
+
eps: float = 1e-6,
|
221 |
+
**kwargs,
|
222 |
+
) -> Union[
|
223 |
+
torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
|
224 |
+
]:
|
225 |
+
""""
|
226 |
+
Wrapper around chunkwise_simple to allow sequences with arbitrary lengths
|
227 |
+
"""
|
228 |
+
tail_size = queries.shape[-2] % chunk_size
|
229 |
+
if tail_size == 0 or queries.shape[-2] < chunk_size:
|
230 |
+
return chunkwise_simple(
|
231 |
+
queries, keys, values, igate_preact, fgate_preact,
|
232 |
+
initial_C, initial_n, initial_m, chunk_size if tail_size == 0 else tail_size,
|
233 |
+
return_last_state, eps, **kwargs
|
234 |
+
)
|
235 |
+
|
236 |
+
sections = [queries.shape[-2] - tail_size, tail_size]
|
237 |
+
head_args, tail_args = zip(*(torch.split(x, sections, dim=-2) for x in [
|
238 |
+
queries, keys, values, igate_preact, fgate_preact
|
239 |
+
]))
|
240 |
+
head_out, state = chunkwise_simple(
|
241 |
+
*head_args, initial_C, initial_n, initial_m,
|
242 |
+
chunk_size=chunk_size, return_last_state=True, eps=eps, **kwargs
|
243 |
+
)
|
244 |
+
tail_out = chunkwise_simple(
|
245 |
+
*tail_args, *state, chunk_size=tail_size,
|
246 |
+
return_last_state=return_last_state, eps=eps, **kwargs
|
247 |
+
)
|
248 |
+
|
249 |
+
if return_last_state:
|
250 |
+
return torch.cat([head_out, tail_out[0]], dim=-2), tail_out[-1]
|
251 |
+
else:
|
252 |
+
return torch.cat([head_out, tail_out], dim=-2)
|
253 |
+
|
254 |
+
|
255 |
+
def recurrent_step_stabilized_simple(
|
256 |
+
c_state: torch.Tensor,
|
257 |
+
n_state: torch.Tensor,
|
258 |
+
m_state: torch.Tensor,
|
259 |
+
q: torch.Tensor,
|
260 |
+
k: torch.Tensor,
|
261 |
+
v: torch.Tensor,
|
262 |
+
igate_preact: torch.Tensor,
|
263 |
+
fgate_preact: torch.Tensor,
|
264 |
+
eps: float = 1e-6,
|
265 |
+
**kwargs,
|
266 |
+
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
267 |
+
"""This is a single step of the mLSTM operation in recurrent form.
|
268 |
+
|
269 |
+
Args:
|
270 |
+
c_state (torch.Tensor): (B, NH, DH, DH)
|
271 |
+
n_state (torch.Tensor): (B, NH, DH, 1)
|
272 |
+
m_state (torch.Tensor): (B, NH, 1, 1)
|
273 |
+
q (torch.Tensor): (B, NH, 1, DH)
|
274 |
+
k (torch.Tensor): (B, NH, 1, DH)
|
275 |
+
v (torch.Tensor): (B, NH, 1, DH)
|
276 |
+
igate_preact (torch.Tensor): (B, NH, 1, 1)
|
277 |
+
fgate_preact (torch.Tensor): (B, NH, 1, 1)
|
278 |
+
|
279 |
+
Returns:
|
280 |
+
tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
281 |
+
(hidden_state [B, NH, DH], (c_state_new [B, NH, DH, DH], n_state_new [B, NH, DH, 1], m_state_new [B, NH, 1, 1]))
|
282 |
+
"""
|
283 |
+
B, NH, S, DH = q.shape
|
284 |
+
# projections
|
285 |
+
q, k, v = (
|
286 |
+
q.squeeze_(2).unsqueeze(-1),
|
287 |
+
k.squeeze_(2).unsqueeze(-1),
|
288 |
+
v.squeeze_(2).unsqueeze(-1),
|
289 |
+
) # (B, NH, DH, 1)
|
290 |
+
|
291 |
+
# gates
|
292 |
+
log_fg_act = torch.nn.functional.logsigmoid(fgate_preact) # (B, NH, 1, 1)
|
293 |
+
|
294 |
+
# update rule
|
295 |
+
m_state_new = torch.max(log_fg_act + m_state, igate_preact) # (B, NH, 1, 1)
|
296 |
+
|
297 |
+
fg_act = torch.exp(log_fg_act + m_state - m_state_new) # (B, NH, 1, 1)
|
298 |
+
ig_act = torch.exp(igate_preact - m_state_new) # (B, NH, 1, 1)
|
299 |
+
|
300 |
+
k_scaled = k / math.sqrt(DH)
|
301 |
+
|
302 |
+
c_state_new = fg_act * c_state + ig_act * (
|
303 |
+
k_scaled @ v.transpose(-1, -2)
|
304 |
+
) # (B, NH, DH, DH)
|
305 |
+
n_state_new = fg_act * n_state + ig_act * k_scaled # (B, NH, DH, 1)
|
306 |
+
|
307 |
+
h_num = q.transpose(-1, -2) @ c_state_new # (B, NH, 1, DH)
|
308 |
+
|
309 |
+
qn_dotproduct = q.transpose(-1, -2) @ n_state_new # (B, NH, 1, 1)
|
310 |
+
max_val = torch.exp(-m_state_new) # (B, NH, 1, 1)
|
311 |
+
h_denom = torch.maximum(qn_dotproduct.abs(), max_val) + eps
|
312 |
+
h = h_num / h_denom # (B, NH, 1, DH) / (B, NH, 1, 1) = (B, NH, 1, DH)
|
313 |
+
|
314 |
+
return h, (c_state_new, n_state_new, m_state_new)
|
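The module header notes that numerical issues between the parallel and stepwise backends were fixed, so the full-sequence and token-by-token forms are expected to agree. A small consistency sketch under that assumption (shapes follow the docstrings above; the import path and the printed tolerance check are illustrative):

```python
import torch

from protxlstm.xlstm.blocks.mlstm.backends import (
    parallel_stabilized_simple,
    recurrent_step_stabilized_simple,
)

B, NH, S, DH = 1, 2, 16, 8
q, k, v = (torch.randn(B, NH, S, DH) for _ in range(3))
igate = torch.randn(B, NH, S, 1)
fgate = torch.randn(B, NH, S, 1)

# full-sequence (parallel) form
h_parallel = parallel_stabilized_simple(q, k, v, igate, fgate)

# token-by-token (recurrent) form, starting from empty (C, n, m) states
c = torch.zeros(B, NH, DH, DH)
n = torch.zeros(B, NH, DH, 1)
m = torch.zeros(B, NH, 1, 1)
steps = []
for t in range(S):
    h_t, (c, n, m) = recurrent_step_stabilized_simple(
        c, n, m,
        q[:, :, t : t + 1], k[:, :, t : t + 1], v[:, :, t : t + 1],
        igate[:, :, t : t + 1], fgate[:, :, t : t + 1],
    )
    steps.append(h_t)
h_recurrent = torch.cat(steps, dim=2)

# expected to be small (floating-point error only)
print((h_parallel - h_recurrent).abs().max().item())
```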
protxlstm/xlstm/blocks/mlstm/block.py
ADDED
@@ -0,0 +1,27 @@
1 |
+
# Copyright (c) NXAI GmbH and its affiliates 2024
|
2 |
+
# Maximilian Beck
|
3 |
+
|
4 |
+
# Modified by Pieter-Jan Hoedt, Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
|
5 |
+
# - Remove sLSTM
|
6 |
+
|
7 |
+
|
8 |
+
from dataclasses import dataclass, field
|
9 |
+
|
10 |
+
from ..xlstm_block import xLSTMBlock, xLSTMBlockConfig
|
11 |
+
from .layer import mLSTMLayerConfig
|
12 |
+
|
13 |
+
|
14 |
+
@dataclass
|
15 |
+
class mLSTMBlockConfig:
|
16 |
+
mlstm: mLSTMLayerConfig = field(default_factory=mLSTMLayerConfig)
|
17 |
+
|
18 |
+
def __post_init__(self):
|
19 |
+
self.mlstm.__post_init__()
|
20 |
+
|
21 |
+
|
22 |
+
class mLSTMBlock(xLSTMBlock):
|
23 |
+
|
24 |
+
config_class = mLSTMBlockConfig
|
25 |
+
|
26 |
+
def __init__(self, config: mLSTMBlockConfig) -> None:
|
27 |
+
super().__init__(config=xLSTMBlockConfig(mlstm=config.mlstm, feedforward=None))
|
protxlstm/xlstm/blocks/mlstm/cell.py
ADDED
@@ -0,0 +1,212 @@
1 |
+
# Copyright (c) NXAI GmbH and its affiliates 2024
|
2 |
+
# Maximilian Beck
|
3 |
+
|
4 |
+
# Modified by Pieter-Jan Hoedt, Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
|
5 |
+
# - Add references to chunkwise backends
|
6 |
+
# - Modify forward to take and return state
|
7 |
+
|
8 |
+
|
9 |
+
from dataclasses import dataclass
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from torch import nn
|
13 |
+
from functools import partial
|
14 |
+
|
15 |
+
from ...components.init import bias_linspace_init_
|
16 |
+
from ...components.ln import MultiHeadLayerNorm
|
17 |
+
from .backends import parallel_stabilized_simple, chunkwise_simple, chunkwise_variable, recurrent_step_stabilized_simple
|
18 |
+
|
19 |
+
|
20 |
+
@dataclass
|
21 |
+
class mLSTMCellConfig:
|
22 |
+
context_length: int = -1
|
23 |
+
embedding_dim: int = -1
|
24 |
+
num_heads: int = -1
|
25 |
+
backend: str = "parallel" # "chunkwise"
|
26 |
+
chunk_size: int = 64
|
27 |
+
return_last_state: bool = False
|
28 |
+
|
29 |
+
|
30 |
+
class mLSTMCell(nn.Module):
|
31 |
+
config_class = mLSTMCellConfig
|
32 |
+
|
33 |
+
def __init__(self, config: mLSTMCellConfig):
|
34 |
+
super().__init__()
|
35 |
+
self.config = config
|
36 |
+
|
37 |
+
if self.config.return_last_state:
|
38 |
+
assert config.backend != "parallel", "Parallel backend cannot return state - set return_last_state to False or use a chunkwise backend."
|
39 |
+
|
40 |
+
if config.backend == "parallel":
|
41 |
+
self.backend_fn = parallel_stabilized_simple
|
42 |
+
elif config.backend == "chunkwise":
|
43 |
+
chunkwise_backend = partial(chunkwise_simple, chunk_size=config.chunk_size, return_last_state=config.return_last_state)
|
44 |
+
self.backend_fn = chunkwise_backend
|
45 |
+
elif config.backend == "chunkwise_variable":
|
46 |
+
chunkwise_backend = partial(chunkwise_variable, chunk_size=config.chunk_size, return_last_state=config.return_last_state)
|
47 |
+
self.backend_fn = chunkwise_backend
|
48 |
+
else:
|
49 |
+
raise ValueError(f"Unknown mLSTM backend: {config.backend}")
|
50 |
+
self.backend_fn_step = recurrent_step_stabilized_simple
|
51 |
+
|
52 |
+
self.igate = nn.Linear(3 * config.embedding_dim, config.num_heads)
|
53 |
+
self.fgate = nn.Linear(3 * config.embedding_dim, config.num_heads)
|
54 |
+
|
55 |
+
self.outnorm = MultiHeadLayerNorm(ndim=config.embedding_dim, weight=True, bias=False)
|
56 |
+
|
57 |
+
if config.backend == "parallel":
|
58 |
+
self.register_buffer(
|
59 |
+
"causal_mask",
|
60 |
+
torch.tril(torch.ones(config.context_length, config.context_length, dtype=torch.bool)),
|
61 |
+
persistent=False,
|
62 |
+
)
|
63 |
+
else:
|
64 |
+
self.causal_mask = None
|
65 |
+
|
66 |
+
self.reset_parameters()
|
67 |
+
|
68 |
+
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, state = None, **kwargs) -> torch.Tensor:
|
69 |
+
B, S, _ = q.shape # (B, S, H)
|
70 |
+
|
71 |
+
if_gate_input = torch.cat([q, k, v], dim=-1)
|
72 |
+
q = q.view(B, S, self.config.num_heads, -1) # (B, S, NH, DH)
|
73 |
+
k = k.view(B, S, self.config.num_heads, -1) # (B, S, NH, DH)
|
74 |
+
v = v.view(B, S, self.config.num_heads, -1) # (B, S, NH, DH)
|
75 |
+
|
76 |
+
q = q.transpose(1, 2) # (B, NH, S, DH)
|
77 |
+
k = k.transpose(1, 2) # (B, NH, S, DH)
|
78 |
+
v = v.transpose(1, 2) # (B, NH, S, DH)
|
79 |
+
|
80 |
+
# compute input and forget gate pre-activations
|
81 |
+
igate_preact = self.igate(if_gate_input) # (B, S, NH)
|
82 |
+
igate_preact = igate_preact.transpose(-1, -2).unsqueeze(-1) # (B, NH, S, 1)
|
83 |
+
fgate_preact = self.fgate(if_gate_input) # (B, S, NH)
|
84 |
+
fgate_preact = fgate_preact.transpose(-1, -2).unsqueeze(-1)  # (B, NH, S, 1)
|
85 |
+
|
86 |
+
if state is not None and self.config.backend in ["chunkwise", "chunkwise_variable"]:
|
87 |
+
|
88 |
+
initial_C, initial_n, initial_m = state
|
89 |
+
|
90 |
+
if self.config.return_last_state:
|
91 |
+
|
92 |
+
h_state, mlstm_state = self.backend_fn(
|
93 |
+
queries=q,
|
94 |
+
keys=k,
|
95 |
+
values=v,
|
96 |
+
igate_preact=igate_preact,
|
97 |
+
fgate_preact=fgate_preact,
|
98 |
+
initial_C=initial_C,
|
99 |
+
initial_n=initial_n,
|
100 |
+
initial_m=initial_m,
|
101 |
+
lower_triangular_matrix=self.causal_mask,
|
102 |
+
)
|
103 |
+
|
104 |
+
else:
|
105 |
+
h_state = self.backend_fn(
|
106 |
+
queries=q,
|
107 |
+
keys=k,
|
108 |
+
values=v,
|
109 |
+
igate_preact=igate_preact,
|
110 |
+
fgate_preact=fgate_preact,
|
111 |
+
initial_C=initial_C,
|
112 |
+
initial_n=initial_n,
|
113 |
+
initial_m=initial_m,
|
114 |
+
lower_triangular_matrix=self.causal_mask,
|
115 |
+
) # (B, NH, S, DH)
|
116 |
+
|
117 |
+
else:
|
118 |
+
if self.config.return_last_state:
|
119 |
+
h_state, mlstm_state = self.backend_fn(
|
120 |
+
queries=q,
|
121 |
+
keys=k,
|
122 |
+
values=v,
|
123 |
+
igate_preact=igate_preact,
|
124 |
+
fgate_preact=fgate_preact,
|
125 |
+
lower_triangular_matrix=self.causal_mask,
|
126 |
+
)
|
127 |
+
|
128 |
+
else:
|
129 |
+
h_state = self.backend_fn(
|
130 |
+
queries=q,
|
131 |
+
keys=k,
|
132 |
+
values=v,
|
133 |
+
igate_preact=igate_preact,
|
134 |
+
fgate_preact=fgate_preact,
|
135 |
+
lower_triangular_matrix=self.causal_mask,
|
136 |
+
) # (B, NH, S, DH)
|
137 |
+
|
138 |
+
h_state_norm = self.outnorm(h_state) # (B, NH, S, DH)
|
139 |
+
h_state_norm = h_state_norm.transpose(1, 2).reshape(B, S, -1) # (B, NH, S, DH) -> (B, S, NH, DH) -> (B, S, H)
|
140 |
+
|
141 |
+
if self.config.return_last_state:
|
142 |
+
return h_state_norm, mlstm_state
|
143 |
+
else:
|
144 |
+
return h_state_norm
|
145 |
+
|
146 |
+
def step(
|
147 |
+
self,
|
148 |
+
q: torch.Tensor,
|
149 |
+
k: torch.Tensor,
|
150 |
+
v: torch.Tensor,
|
151 |
+
mlstm_state: tuple[torch.Tensor, torch.Tensor, torch.Tensor] = None,
|
152 |
+
**kwargs,
|
153 |
+
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
|
154 |
+
B, S, _ = q.shape # (B, S, H)
|
155 |
+
assert S == 1, f"mLSTMCell.step only supports sequence length S=1, but got S={S}."
|
156 |
+
|
157 |
+
if_gate_input = torch.cat([q, k, v], dim=-1)
|
158 |
+
q = q.view(B, S, self.config.num_heads, -1) # (B, S, NH, DH)
|
159 |
+
k = k.view(B, S, self.config.num_heads, -1) # (B, S, NH, DH)
|
160 |
+
v = v.view(B, S, self.config.num_heads, -1) # (B, S, NH, DH)
|
161 |
+
|
162 |
+
_, _, NH, DH = q.shape
|
163 |
+
|
164 |
+
q = q.transpose(1, 2) # (B, NH, S, DH)
|
165 |
+
k = k.transpose(1, 2) # (B, NH, S, DH)
|
166 |
+
v = v.transpose(1, 2) # (B, NH, S, DH)
|
167 |
+
|
168 |
+
# compute input and forget gate pre-activations
|
169 |
+
igate_preact = self.igate(if_gate_input) # (B, S, NH)
|
170 |
+
igate_preact = igate_preact.transpose(-1, -2).unsqueeze(-1) # (B, NH, S, 1)
|
171 |
+
fgate_preact = self.fgate(if_gate_input) # (B, S, NH)
|
172 |
+
fgate_preact = fgate_preact.transpose(-1, -2).unsqueeze(-1) # (B, NH, S, 1)
|
173 |
+
|
174 |
+
if mlstm_state is None:
|
175 |
+
c_state = torch.zeros(size=(B, NH, DH, DH), device=q.device, dtype=q.dtype)
|
176 |
+
n_state = torch.zeros(size=(B, NH, DH, 1), device=q.device, dtype=q.dtype)
|
177 |
+
m_state = torch.zeros(size=(B, NH, 1, 1), device=q.device, dtype=q.dtype)
|
178 |
+
else:
|
179 |
+
c_state, n_state, m_state = mlstm_state
|
180 |
+
c_state = c_state.to(device=q.device, dtype=q.dtype)
|
181 |
+
n_state = n_state.to(device=q.device, dtype=q.dtype)
|
182 |
+
m_state = m_state.to(device=q.device, dtype=q.dtype)
|
183 |
+
|
184 |
+
assert c_state.shape == (B, NH, DH, DH), f"Expected c_state shape {(B, NH, DH, DH)}, but got {c_state.shape}."
|
185 |
+
assert n_state.shape == (B, NH, DH, 1), f"Expected n_state shape {(B, NH, DH, 1)}, but got {n_state.shape}."
|
186 |
+
assert m_state.shape == (B, NH, 1, 1), f"Expected m_state shape {(B, NH, 1, 1)}, but got {m_state.shape}."
|
187 |
+
|
188 |
+
h_state, mlstm_state = self.backend_fn_step(
|
189 |
+
c_state=c_state,
|
190 |
+
n_state=n_state,
|
191 |
+
m_state=m_state,
|
192 |
+
q=q,
|
193 |
+
k=k,
|
194 |
+
v=v,
|
195 |
+
igate_preact=igate_preact,
|
196 |
+
fgate_preact=fgate_preact,
|
197 |
+
) # (B, NH, 1 DH), ((B, NH, DH, DH), (B, NH, DH, 1), (B, NH, 1, 1))
|
198 |
+
|
199 |
+
h_state_norm = self.outnorm(h_state) # (B, NH, S, DH)
|
200 |
+
h_state_norm = h_state_norm.transpose(1, 2).reshape(B, S, -1) # (B, NH, S, DH) -> (B, S, NH, DH) -> (B, S, H)
|
201 |
+
|
202 |
+
return h_state_norm, mlstm_state
|
203 |
+
|
204 |
+
|
205 |
+
def reset_parameters(self):
|
206 |
+
self.outnorm.reset_parameters()
|
207 |
+
# forget gate initialization
|
208 |
+
torch.nn.init.zeros_(self.fgate.weight)
|
209 |
+
bias_linspace_init_(self.fgate.bias, start=3.0, end=6.0)
|
210 |
+
# input gate initialization
|
211 |
+
torch.nn.init.zeros_(self.igate.weight)
|
212 |
+
torch.nn.init.normal_(self.igate.bias, mean=0.0, std=0.1)
|
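A rough usage sketch of the cell on its own (config values are arbitrary; inside the model they come from the top-level config). It exercises the full-sequence forward and the single-token `step` path that threads the recurrent state:

```python
import torch

from protxlstm.xlstm.blocks.mlstm.cell import mLSTMCell, mLSTMCellConfig

cfg = mLSTMCellConfig(context_length=32, embedding_dim=64, num_heads=4, backend="parallel")
cell = mLSTMCell(cfg)

B, S, H = 2, 32, 64
q, k, v = (torch.randn(B, S, H) for _ in range(3))

h = cell(q, k, v)            # (B, S, H), parallel form over the whole sequence
assert h.shape == (B, S, H)

# incremental decoding: one token at a time, carrying the (C, n, m) state
state = None
h_t, state = cell.step(q[:, :1], k[:, :1], v[:, :1], mlstm_state=state)
assert h_t.shape == (B, 1, H)
```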
protxlstm/xlstm/blocks/mlstm/layer.py
ADDED
@@ -0,0 +1,217 @@
1 |
+
# Copyright (c) NXAI GmbH and its affiliates 2024
|
2 |
+
# Maximilian Beck
|
3 |
+
|
4 |
+
# Modified by Pieter-Jan Hoedt, Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
|
5 |
+
# - Modify forward to take and return state
|
6 |
+
|
7 |
+
|
8 |
+
from dataclasses import dataclass
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from torch import nn
|
12 |
+
|
13 |
+
from ...components.conv import CausalConv1d, CausalConv1dConfig
|
14 |
+
from ...components.init import small_init_init_, wang_init_
|
15 |
+
from ...components.linear_headwise import (
|
16 |
+
LinearHeadwiseExpand,
|
17 |
+
LinearHeadwiseExpandConfig,
|
18 |
+
)
|
19 |
+
from ...utils import UpProjConfigMixin
|
20 |
+
from ...components.rotary_position import apply_rotary_emb
|
21 |
+
from .cell import mLSTMCell, mLSTMCellConfig
|
22 |
+
|
23 |
+
|
24 |
+
@dataclass
|
25 |
+
class mLSTMLayerConfig(UpProjConfigMixin):
|
26 |
+
conv1d_kernel_size: int = 4
|
27 |
+
qkv_proj_blocksize: int = 4
|
28 |
+
num_heads: int = 4
|
29 |
+
proj_factor: float = 2.0
|
30 |
+
|
31 |
+
# will be set toplevel config
|
32 |
+
embedding_dim: int = -1
|
33 |
+
bias: bool = False
|
34 |
+
dropout: float = 0.0
|
35 |
+
context_length: int = -1
|
36 |
+
backend: str = "parallel" # "chunkwise"
|
37 |
+
chunk_size: int = 64
|
38 |
+
return_last_state: bool = False
|
39 |
+
|
40 |
+
_num_blocks: int = 1
|
41 |
+
_inner_embedding_dim: int = None
|
42 |
+
|
43 |
+
def __post_init__(self):
|
44 |
+
self._set_proj_up_dim(embedding_dim=self.embedding_dim)
|
45 |
+
self._inner_embedding_dim = self._proj_up_dim
|
46 |
+
|
47 |
+
|
48 |
+
class mLSTMLayer(nn.Module):
|
49 |
+
config_class = mLSTMLayerConfig
|
50 |
+
|
51 |
+
def __init__(self, config: mLSTMLayerConfig):
|
52 |
+
super().__init__()
|
53 |
+
self.config = config
|
54 |
+
|
55 |
+
self.proj_up = nn.Linear(
|
56 |
+
in_features=self.config.embedding_dim,
|
57 |
+
out_features=2 * self.config._inner_embedding_dim,
|
58 |
+
bias=self.config.bias,
|
59 |
+
)
|
60 |
+
|
61 |
+
num_proj_heads = round(self.config._inner_embedding_dim // self.config.qkv_proj_blocksize)
|
62 |
+
self.q_proj = LinearHeadwiseExpand(
|
63 |
+
config=LinearHeadwiseExpandConfig(
|
64 |
+
in_features=self.config._inner_embedding_dim,
|
65 |
+
num_heads=num_proj_heads,
|
66 |
+
bias=self.config.bias,
|
67 |
+
)
|
68 |
+
)
|
69 |
+
self.k_proj = LinearHeadwiseExpand(
|
70 |
+
config=LinearHeadwiseExpandConfig(
|
71 |
+
in_features=self.config._inner_embedding_dim,
|
72 |
+
num_heads=num_proj_heads,
|
73 |
+
bias=self.config.bias,
|
74 |
+
)
|
75 |
+
)
|
76 |
+
self.v_proj = LinearHeadwiseExpand(
|
77 |
+
config=LinearHeadwiseExpandConfig(
|
78 |
+
in_features=self.config._inner_embedding_dim,
|
79 |
+
num_heads=num_proj_heads,
|
80 |
+
bias=self.config.bias,
|
81 |
+
)
|
82 |
+
)
|
83 |
+
|
84 |
+
self.conv1d = CausalConv1d(
|
85 |
+
config=CausalConv1dConfig(
|
86 |
+
feature_dim=self.config._inner_embedding_dim,
|
87 |
+
kernel_size=self.config.conv1d_kernel_size,
|
88 |
+
)
|
89 |
+
)
|
90 |
+
self.conv_act_fn = nn.SiLU()
|
91 |
+
self.mlstm_cell = mLSTMCell(
|
92 |
+
config=mLSTMCellConfig(
|
93 |
+
context_length=self.config.context_length,
|
94 |
+
embedding_dim=self.config._inner_embedding_dim,
|
95 |
+
num_heads=self.config.num_heads,
|
96 |
+
backend=self.config.backend,
|
97 |
+
chunk_size=self.config.chunk_size,
|
98 |
+
return_last_state = self.config.return_last_state
|
99 |
+
)
|
100 |
+
)
|
101 |
+
self.ogate_act_fn = nn.SiLU()
|
102 |
+
|
103 |
+
self.learnable_skip = nn.Parameter(torch.ones(self.config._inner_embedding_dim, requires_grad=True))
|
104 |
+
|
105 |
+
self.proj_down = nn.Linear(
|
106 |
+
in_features=self.config._inner_embedding_dim,
|
107 |
+
out_features=self.config.embedding_dim,
|
108 |
+
bias=self.config.bias,
|
109 |
+
)
|
110 |
+
self.dropout = nn.Dropout(self.config.dropout)
|
111 |
+
self.reset_parameters()
|
112 |
+
|
113 |
+
def forward(self, x: torch.Tensor, freqs_cos=None, freqs_sin=None, state=None, **kwargs) -> torch.Tensor:
|
114 |
+
B, S, _ = x.shape
|
115 |
+
|
116 |
+
# up-projection
|
117 |
+
x_inner = self.proj_up(x)
|
118 |
+
x_mlstm, z = torch.split(x_inner, split_size_or_sections=self.config._inner_embedding_dim, dim=-1)
|
119 |
+
|
120 |
+
# mlstm branch
|
121 |
+
if state is not None:
|
122 |
+
mlstm_state = state["mlstm_state"]
|
123 |
+
conv_state = state["conv_state"][0]
|
124 |
+
else:
|
125 |
+
mlstm_state, conv_state = None, None
|
126 |
+
|
127 |
+
if self.config.return_last_state:
|
128 |
+
x_mlstm_conv, conv_state = self.conv1d(x_mlstm, conv_state = conv_state, return_last_state = True)
|
129 |
+
else:
|
130 |
+
x_mlstm_conv = self.conv1d(x_mlstm, conv_state = conv_state)
|
131 |
+
|
132 |
+
x_mlstm_conv_act = self.conv_act_fn(x_mlstm_conv)
|
133 |
+
|
134 |
+
q = self.q_proj(x_mlstm_conv_act)
|
135 |
+
k = self.k_proj(x_mlstm_conv_act)
|
136 |
+
v = self.v_proj(x_mlstm)
|
137 |
+
|
138 |
+
if freqs_cos is not None and freqs_sin is not None:
|
139 |
+
q, k = apply_rotary_emb(q, k, freqs_cos, freqs_sin)
|
140 |
+
|
141 |
+
if self.config.return_last_state:
|
142 |
+
h_tilde_state, mlstm_state = self.mlstm_cell(q=q, k=k, v=v, state=mlstm_state, **kwargs)
|
143 |
+
else:
|
144 |
+
h_tilde_state = self.mlstm_cell(q=q, k=k, v=v, state=mlstm_state, **kwargs)
|
145 |
+
|
146 |
+
h_tilde_state_skip = h_tilde_state + (self.learnable_skip * x_mlstm_conv_act)
|
147 |
+
|
148 |
+
# output / z branch
|
149 |
+
h_state = h_tilde_state_skip * self.ogate_act_fn(z)
|
150 |
+
|
151 |
+
# down-projection
|
152 |
+
y = self.dropout(self.proj_down(h_state))
|
153 |
+
|
154 |
+
if self.config.return_last_state:
|
155 |
+
return y, {"mlstm_state": mlstm_state, "conv_state": (conv_state,)}
|
156 |
+
else:
|
157 |
+
return y
|
158 |
+
|
159 |
+
def step(
|
160 |
+
self,
|
161 |
+
x: torch.Tensor,
|
162 |
+
freqs_cos=None,
|
163 |
+
freqs_sin=None,
|
164 |
+
mlstm_state: tuple[torch.Tensor, torch.Tensor, torch.Tensor] = None,
|
165 |
+
conv_state: tuple[torch.Tensor] = None,
|
166 |
+
) -> tuple[torch.Tensor, dict[str, tuple[torch.Tensor, ...]]]:
|
167 |
+
B, S, _ = x.shape
|
168 |
+
|
169 |
+
# up-projection
|
170 |
+
x_inner = self.proj_up(x)
|
171 |
+
x_mlstm, z = torch.split(x_inner, split_size_or_sections=self.config._inner_embedding_dim, dim=-1)
|
172 |
+
|
173 |
+
# mlstm branch
|
174 |
+
x_mlstm_conv, conv_state = self.conv1d.step(x_mlstm, conv_state=conv_state)
|
175 |
+
x_mlstm_conv_act = self.conv_act_fn(x_mlstm_conv)
|
176 |
+
|
177 |
+
q = self.q_proj(x_mlstm_conv_act)
|
178 |
+
k = self.k_proj(x_mlstm_conv_act)
|
179 |
+
v = self.v_proj(x_mlstm)
|
180 |
+
|
181 |
+
if freqs_cos is not None and freqs_sin is not None:
|
182 |
+
q, k = apply_rotary_emb(q, k, freqs_cos, freqs_sin)
|
183 |
+
|
184 |
+
h_tilde_state, mlstm_state = self.mlstm_cell.step(q=q, k=k, v=v, mlstm_state=mlstm_state)
|
185 |
+
|
186 |
+
h_tilde_state_skip = h_tilde_state + (self.learnable_skip * x_mlstm_conv_act)
|
187 |
+
|
188 |
+
# output / z branch
|
189 |
+
h_state = h_tilde_state_skip * self.ogate_act_fn(z)
|
190 |
+
|
191 |
+
# down-projection
|
192 |
+
y = self.dropout(self.proj_down(h_state))
|
193 |
+
return y, {"mlstm_state": mlstm_state, "conv_state": conv_state}
|
194 |
+
|
195 |
+
def reset_parameters(self):
|
196 |
+
# init inproj
|
197 |
+
small_init_init_(self.proj_up.weight, dim=self.config.embedding_dim)
|
198 |
+
if self.proj_up.bias is not None:
|
199 |
+
nn.init.zeros_(self.proj_up.bias)
|
200 |
+
# init outproj
|
201 |
+
wang_init_(self.proj_down.weight, dim=self.config.embedding_dim, num_blocks=self.config._num_blocks)
|
202 |
+
if self.proj_down.bias is not None:
|
203 |
+
nn.init.zeros_(self.proj_down.bias)
|
204 |
+
|
205 |
+
nn.init.ones_(self.learnable_skip)
|
206 |
+
|
207 |
+
def _init_qkv_proj(qkv_proj: LinearHeadwiseExpand):
|
208 |
+
# use the embedding dim instead of the inner embedding dim
|
209 |
+
small_init_init_(qkv_proj.weight, dim=self.config.embedding_dim)
|
210 |
+
if qkv_proj.bias is not None:
|
211 |
+
nn.init.zeros_(qkv_proj.bias)
|
212 |
+
|
213 |
+
_init_qkv_proj(self.q_proj)
|
214 |
+
_init_qkv_proj(self.k_proj)
|
215 |
+
_init_qkv_proj(self.v_proj)
|
216 |
+
|
217 |
+
self.mlstm_cell.reset_parameters()
|
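A similar sketch for the full layer, which wraps the cell with the up/down projections, causal convolution and output gating (dimensions are illustrative; `embedding_dim` and `context_length` are normally filled in by the top-level config):

```python
import torch

from protxlstm.xlstm.blocks.mlstm.layer import mLSTMLayer, mLSTMLayerConfig

layer = mLSTMLayer(mLSTMLayerConfig(embedding_dim=64, context_length=32, num_heads=4))

x = torch.randn(2, 32, 64)   # (B, S, embedding_dim)
y = layer(x)                 # (B, S, embedding_dim); with return_last_state=True a state dict is returned as well
assert y.shape == x.shape
```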
protxlstm/xlstm/blocks/xlstm_block.py
ADDED
@@ -0,0 +1,111 @@
1 |
+
# Copyright (c) NXAI GmbH and its affiliates 2024
|
2 |
+
# Maximilian Beck
|
3 |
+
|
4 |
+
# Modified by Pieter-Jan Hoedt, Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
|
5 |
+
# - Remove sLSTM
|
6 |
+
# - Modify forward to take and return state
|
7 |
+
|
8 |
+
|
9 |
+
from dataclasses import dataclass
|
10 |
+
from typing import Optional
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from torch import nn
|
14 |
+
|
15 |
+
from ..components.feedforward import FeedForwardConfig, create_feedforward
|
16 |
+
from ..components.ln import LayerNorm
|
17 |
+
from .mlstm.layer import mLSTMLayer, mLSTMLayerConfig
|
18 |
+
|
19 |
+
"""An xLSTM block can be either an sLSTM Block or an mLSTM Block.
|
20 |
+
In this repository only mLSTM is implemented.
|
21 |
+
|
22 |
+
It contains the pre-LayerNorms and the skip connections.
|
23 |
+
"""
|
24 |
+
|
25 |
+
|
26 |
+
@dataclass
|
27 |
+
class xLSTMBlockConfig:
|
28 |
+
mlstm: Optional[mLSTMLayerConfig] = None
|
29 |
+
|
30 |
+
feedforward: Optional[FeedForwardConfig] = None
|
31 |
+
|
32 |
+
_num_blocks: int = 1
|
33 |
+
_block_idx: int = 0
|
34 |
+
|
35 |
+
def __post_init__(self):
|
36 |
+
assert (
|
37 |
+
self.mlstm is not None
|
38 |
+
), "mlstm config must be provided"
|
39 |
+
|
40 |
+
embedding_dim = (
|
41 |
+
self.mlstm.embedding_dim
|
42 |
+
)
|
43 |
+
|
44 |
+
self.mlstm._num_blocks = self._num_blocks
|
45 |
+
self.mlstm._block_idx = self._block_idx
|
46 |
+
|
47 |
+
if self.feedforward:
|
48 |
+
self.feedforward.embedding_dim = embedding_dim
|
49 |
+
self.feedforward._num_blocks = self._num_blocks
|
50 |
+
self.feedforward.__post_init__()
|
51 |
+
|
52 |
+
|
53 |
+
class xLSTMBlock(nn.Module):
|
54 |
+
|
55 |
+
config_class = xLSTMBlockConfig
|
56 |
+
|
57 |
+
def __init__(self, config: xLSTMBlockConfig) -> None:
|
58 |
+
super().__init__()
|
59 |
+
self.config = config
|
60 |
+
embedding_dim = (
|
61 |
+
self.config.mlstm.embedding_dim
|
62 |
+
)
|
63 |
+
|
64 |
+
self.xlstm_norm = LayerNorm(ndim=embedding_dim, weight=True, bias=False)
|
65 |
+
|
66 |
+
if self.config.mlstm is not None:
|
67 |
+
self.xlstm = mLSTMLayer(config=self.config.mlstm)
|
68 |
+
else:
|
69 |
+
raise ValueError("mlstm must be provided")
|
70 |
+
|
71 |
+
if self.config.feedforward is not None:
|
72 |
+
self.ffn_norm = LayerNorm(
|
73 |
+
ndim=self.config.feedforward.embedding_dim, weight=True, bias=False
|
74 |
+
)
|
75 |
+
self.ffn = create_feedforward(config=self.config.feedforward)
|
76 |
+
else:
|
77 |
+
self.ffn_norm = None
|
78 |
+
self.ffn = None
|
79 |
+
|
80 |
+
self.reset_parameters()
|
81 |
+
|
82 |
+
def forward(self, x: torch.Tensor, state=None, **kwargs) -> torch.Tensor:
|
83 |
+
if self.config.mlstm.return_last_state:
|
84 |
+
x_xlstm, xlstm_state = self.xlstm(self.xlstm_norm(x), state=state, **kwargs)
|
85 |
+
x = x + x_xlstm
|
86 |
+
else:
|
87 |
+
x = x + self.xlstm(self.xlstm_norm(x), state=state, **kwargs)
|
88 |
+
|
89 |
+
if self.ffn is not None:
|
90 |
+
x = x + self.ffn(self.ffn_norm(x), **kwargs)
|
91 |
+
|
92 |
+
if self.config.mlstm.return_last_state:
|
93 |
+
return x, xlstm_state
|
94 |
+
else:
|
95 |
+
return x
|
96 |
+
|
97 |
+
def step(self, x: torch.Tensor, **kwargs) -> tuple[torch.Tensor, dict[str, tuple[torch.Tensor, ...]]]:
|
98 |
+
x_xlstm, xlstm_state = self.xlstm.step(self.xlstm_norm(x), **kwargs)
|
99 |
+
x = x + x_xlstm
|
100 |
+
if self.ffn is not None:
|
101 |
+
x = x + self.ffn(self.ffn_norm(x), **kwargs)
|
102 |
+
return x, xlstm_state
|
103 |
+
|
104 |
+
def reset_parameters(self) -> None:
|
105 |
+
|
106 |
+
self.xlstm.reset_parameters()
|
107 |
+
self.xlstm_norm.reset_parameters()
|
108 |
+
|
109 |
+
if self.ffn is not None:
|
110 |
+
self.ffn.reset_parameters()
|
111 |
+
self.ffn_norm.reset_parameters()
|
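The block wraps the layer with a pre-LayerNorm and a residual connection, and both the block and the layer config are re-exported from the package `__init__` above. A minimal sketch, assuming those exports:

```python
import torch

from protxlstm.xlstm import mLSTMBlock, mLSTMBlockConfig, mLSTMLayerConfig

block = mLSTMBlock(
    mLSTMBlockConfig(mlstm=mLSTMLayerConfig(embedding_dim=64, context_length=32, num_heads=4))
)

x = torch.randn(2, 32, 64)
y = block(x)                 # residual: x + mLSTM(LayerNorm(x))
assert y.shape == x.shape
```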
protxlstm/xlstm/components/__init__.py
ADDED
File without changes
|
protxlstm/xlstm/components/conv.py
ADDED
@@ -0,0 +1,163 @@
```python
# Copyright (c) NXAI GmbH and its affiliates 2024
# Maximilian Beck, Korbinian Pöppel
#
# Modified by Pieter-Jan Hoedt, Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
#  - Modify forward to take and return state

from dataclasses import dataclass, field
from typing import Optional

import torch

# from einops import rearrange
from torch import nn


@dataclass
class CausalConv1dConfig:
    feature_dim: int = None  # F
    kernel_size: int = 4
    causal_conv_bias: bool = True
    channel_mixing: bool = False
    conv1d_kwargs: dict = field(default_factory=dict)

    def __post_init__(self):
        assert self.kernel_size >= 0, "kernel_size must be >= 0"


def conv1d_step(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    conv1d_weight: torch.Tensor,
    conv1d_bias: torch.Tensor = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    B: batch size
    S: sequence length
    D: feature dimension
    KS: kernel size
    Args:
        x (torch.Tensor): (B, S, D)
        conv_state (torch.Tensor): (B, KS, D)
        conv1d_weight (torch.Tensor): (KS, D)
    """
    assert (
        x.shape[0] == conv_state.shape[0]
    ), f"x has batch size {x.shape[0]} but conv_state has batch size {conv_state.shape[0]}"
    assert (
        x.shape[2] == conv_state.shape[2]
    ), f"x has feature dimension {x.shape[2]} but conv_state has feature dimension {conv_state.shape[2]}"
    assert x.shape[1] == 1, f"x has sequence length {x.shape[1]} but it should be 1"
    conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=1))
    conv_state[:, -1:, :] = x
    y = torch.sum(conv_state * conv1d_weight, dim=1, keepdim=True)
    if conv1d_bias is not None:
        y += conv1d_bias
    return y, conv_state


class CausalConv1d(nn.Module):
    config_class = CausalConv1dConfig
    """
    Implements causal depthwise convolution of a time series tensor.
    Input:  Tensor of shape (B,T,F), i.e. (batch, time, feature)
    Output: Tensor of shape (B,T,F)

    Args:
        feature_dim: number of features in the input tensor
        kernel_size: size of the kernel for the depthwise convolution
        causal_conv_bias: whether to use bias in the depthwise convolution
        channel_mixing: whether to use channel mixing (i.e. groups=1) or not (i.e. groups=feature_dim)
                        If True, it mixes the convolved features across channels.
                        If False, all the features are convolved independently.
    """

    def __init__(self, config: CausalConv1dConfig):
        super().__init__()
        self.config = config
        self.groups = self.config.feature_dim
        if self.config.channel_mixing:
            self.groups = 1
        if self.config.kernel_size == 0:
            self.conv = None  # Noop
        else:
            self.pad = (
                self.config.kernel_size - 1
            )  # padding of this size assures temporal causality.
            self.conv = nn.Conv1d(
                in_channels=self.config.feature_dim,
                out_channels=self.config.feature_dim,
                kernel_size=self.config.kernel_size,
                padding=self.pad,
                groups=self.groups,
                bias=self.config.causal_conv_bias,
                **self.config.conv1d_kwargs,
            )  # B, C, L
        self.reset_parameters()

    def reset_parameters(self, **kwargs):
        self.conv.reset_parameters()

    def _create_weight_decay_optim_groups(
        self,
    ) -> tuple[set[nn.Parameter], set[nn.Parameter]]:
        if self.config.kernel_size == 0:
            return (), ()
        else:
            weight_decay = (self.conv.weight,)
            no_weight_decay = ()
            if self.config.causal_conv_bias:
                no_weight_decay += (self.conv.bias,)
            return weight_decay, no_weight_decay

    def forward(
        self,
        x: torch.Tensor,
        conv_state: Optional[torch.Tensor] = None,
        return_last_state: bool = False,
    ) -> torch.Tensor:
        if conv_state is not None:
            conv_state = conv_state[:, -self.pad:]
            x = torch.cat([conv_state, x], dim=1)

        if self.config.kernel_size == 0:
            return x
        y = x.transpose(2, 1)  # (B,F,T) tensor - now in the right shape for conv layer.
        y = self.conv(y)  # (B,F,T+pad) tensor
        if conv_state is not None:
            y = y[:, :, conv_state.shape[1]:]

        if return_last_state:
            return y[:, :, :-self.pad].transpose(2, 1), x[:, -self.config.kernel_size:]  # [:, -self.pad :]
        else:
            return y[:, :, :-self.pad].transpose(2, 1)

    def step(
        self,
        x: torch.Tensor,
        conv_state: tuple[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor]]:

        if self.config.kernel_size == 0:
            return x, conv_state

        B, S, D = x.shape

        if conv_state is None:
            conv_state = (
                torch.zeros(
                    size=(B, self.config.kernel_size, D),
                    device=self.conv.weight.device,
                    dtype=self.conv.weight.dtype,
                ),
            )

        y, conv_state = conv1d_step(
            x,
            conv_state[0],
            self.conv.weight[:, 0, :].transpose(0, 1),  # rearrange(, "D 1 KS -> KS D")
            conv1d_bias=self.conv.bias if self.config.causal_conv_bias else None,
        )
        return y, (conv_state,)
```
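A minimal usage sketch (not part of the diff, and assuming the repository root is on `PYTHONPATH` so the module imports as `protxlstm.xlstm.components.conv`): it runs the causal convolution once over a full sequence and once token-by-token through `step()`, which should produce matching outputs up to numerical precision. The config values are illustrative.

```python
import torch
from protxlstm.xlstm.components.conv import CausalConv1d, CausalConv1dConfig

config = CausalConv1dConfig(feature_dim=16, kernel_size=4)
conv = CausalConv1d(config)

x = torch.randn(2, 10, 16)  # (B, T, F)
y_full = conv(x)            # (B, T, F); y[:, t] depends only on x[:, :t+1]

# Token-by-token decoding with an explicit convolution state.
state = None
ys = []
for t in range(x.shape[1]):
    y_t, state = conv.step(x[:, t:t + 1, :], conv_state=state)
    ys.append(y_t)
y_step = torch.cat(ys, dim=1)

print(y_full.shape, y_step.shape)                    # torch.Size([2, 10, 16]) twice
print(torch.max(torch.abs(y_full - y_step)).item())  # expected to be ~0
```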
protxlstm/xlstm/components/feedforward.py
ADDED
@@ -0,0 +1,88 @@
```python
# Copyright (c) NXAI GmbH and its affiliates 2024
# Maximilian Beck
from dataclasses import dataclass
from typing import Callable, Literal

import torch
from torch import nn

from ..utils import UpProjConfigMixin
from .init import small_init_init_, wang_init_

_act_fn_registry = {
    "gelu": nn.functional.gelu,
    "relu": nn.functional.relu,
    "relu^2": lambda x: torch.square(nn.functional.relu(x)),
    "sigmoid": nn.functional.sigmoid,
    "swish": nn.functional.silu,
    "selu": nn.functional.selu,
}


def get_act_fn(act_fn_name: str) -> Callable[[torch.Tensor], torch.Tensor]:
    if act_fn_name in _act_fn_registry:
        return _act_fn_registry[act_fn_name]
    else:
        assert (
            False
        ), f'Unknown activation function name "{act_fn_name}". Available activation functions are: {str(_act_fn_registry.keys())}'


@dataclass
class FeedForwardConfig(UpProjConfigMixin):
    proj_factor: float = 1.3
    act_fn: str = "gelu"
    embedding_dim: int = -1
    dropout: float = 0.0
    bias: bool = False
    ff_type: Literal["ffn_gated"] = "ffn_gated"

    _num_blocks: int = 1

    def __post_init__(self):
        self._set_proj_up_dim(embedding_dim=self.embedding_dim)
        assert self.act_fn in _act_fn_registry, f"Unknown activation function {self.act_fn}"


class GatedFeedForward(nn.Module):
    config_class = FeedForwardConfig

    def __init__(self, config: FeedForwardConfig):
        super().__init__()
        self.config = config

        self.proj_up = nn.Linear(
            in_features=self.config.embedding_dim,
            out_features=2 * self.config._proj_up_dim,
            bias=self.config.bias,
        )
        self.proj_down = nn.Linear(
            in_features=self.config._proj_up_dim,
            out_features=self.config.embedding_dim,
            bias=self.config.bias,
        )

        self.act_fn = get_act_fn(self.config.act_fn)

        self.dropout = nn.Dropout(self.config.dropout)
        self.reset_parameters()

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        gate_preact, up_proj = self.proj_up(x).split(self.config._proj_up_dim, dim=-1)
        x = self.dropout(self.proj_down(self.act_fn(gate_preact) * up_proj))
        return x

    def reset_parameters(self):
        small_init_init_(self.proj_up.weight, dim=self.config.embedding_dim)
        if self.proj_up.bias is not None:
            nn.init.zeros_(self.proj_up.bias)
        wang_init_(self.proj_down.weight, dim=self.config.embedding_dim, num_blocks=self.config._num_blocks)
        if self.proj_down.bias is not None:
            nn.init.zeros_(self.proj_down.bias)


def create_feedforward(config: FeedForwardConfig) -> nn.Module:
    if config.ff_type == "ffn_gated":
        return GatedFeedForward(config)
    else:
        raise ValueError(f"Unknown feedforward type {config.ff_type}")
```
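A standalone sketch of the gating computation used by `GatedFeedForward` above, for illustration only: the up-projection is split into a gate pre-activation and an up-projected value, the activated gate multiplies the value, and the result is projected back down. In the real module the up-projection width is derived from `proj_factor` via `UpProjConfigMixin` (defined in `xlstm/utils.py`, not shown in this diff); the width used here is an arbitrary stand-in.

```python
import torch
from torch import nn

embedding_dim = 64
proj_up_dim = 96  # illustrative; the module derives this from proj_factor

proj_up = nn.Linear(embedding_dim, 2 * proj_up_dim, bias=False)
proj_down = nn.Linear(proj_up_dim, embedding_dim, bias=False)

x = torch.randn(2, 10, embedding_dim)
gate_preact, up_proj = proj_up(x).split(proj_up_dim, dim=-1)
y = proj_down(nn.functional.gelu(gate_preact) * up_proj)  # gate * value, project back

print(y.shape)  # torch.Size([2, 10, 64])
```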
protxlstm/xlstm/components/init.py
ADDED
@@ -0,0 +1,32 @@
```python
# Copyright (c) NXAI GmbH and its affiliates 2024
# Maximilian Beck
import math

import torch


def bias_linspace_init_(param: torch.Tensor, start: float = 3.4, end: float = 6.0) -> torch.Tensor:
    """Linearly spaced bias init across dimensions."""
    assert param.dim() == 1, f"param must be 1-dimensional (typically a bias), got {param.dim()}"
    n_dims = param.shape[0]
    init_vals = torch.linspace(start, end, n_dims)
    with torch.no_grad():
        param.copy_(init_vals)
    return param


def small_init_init_(param: torch.Tensor, dim: int) -> torch.Tensor:
    """Fills the input Tensor with values according to the method described in Transformers without Tears: Improving
    the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2019), using a normal distribution.
    Adopted from https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py.
    """
    std = math.sqrt(2 / (5 * dim))
    torch.nn.init.normal_(param, mean=0.0, std=std)
    return param


def wang_init_(param: torch.Tensor, dim: int, num_blocks: int):
    """Adopted from https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py."""
    std = 2 / num_blocks / math.sqrt(dim)
    torch.nn.init.normal_(param, mean=0.0, std=std)
    return param
```
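A small illustration of the three initializers (assuming the package imports as `protxlstm.xlstm.components.init`): each operates in place and returns the tensor it was given, so they can be chained into `reset_parameters` methods as done above. The tensor shapes are arbitrary.

```python
import torch
from protxlstm.xlstm.components.init import bias_linspace_init_, small_init_init_, wang_init_

bias = torch.empty(8)
bias_linspace_init_(bias)   # values spaced linearly between 3.4 and 6.0
print(bias)

weight = torch.empty(128, 64)
small_init_init_(weight, dim=64)           # normal init with std = sqrt(2 / (5 * dim))
wang_init_(weight, dim=64, num_blocks=12)  # normal init with std = 2 / num_blocks / sqrt(dim)
```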
protxlstm/xlstm/components/linear_headwise.py
ADDED
@@ -0,0 +1,92 @@
```python
# Copyright (c) NXAI GmbH and its affiliates 2024
# Maximilian Beck, Korbinian Pöppel
from dataclasses import dataclass

from math import sqrt
import torch

# from einops import einsum, rearrange
from torch import nn


@dataclass
class LinearHeadwiseExpandConfig:
    in_features: int = 0
    # this is the number of heads that the in_features are split into
    # if num_heads=1, this is a normal linear layer
    # if num_heads>1, the in_features are split into num_heads and each head is projected separately
    # if num_heads=in_features, each feature is projected separately
    num_heads: int = -1
    expand_factor_up: float = 1

    # this is internally computed
    # but can be overwritten if you want to use a different output dimension
    # if > 0 the expand factor is ignored
    _out_features: int = -1

    bias: bool = True
    trainable_weight: bool = True
    trainable_bias: bool = True

    def __post_init__(self):
        assert self.num_heads > 0, "num_heads must be set"
        assert self.num_heads <= self.in_features, "num_heads must be <= in_features"
        assert (
            self.in_features % self.num_heads == 0
        ), "in_features must be a multiple of num_heads"

        if self._out_features < 0:
            self._out_features = round(self.expand_factor_up * self.in_features)


class LinearHeadwiseExpand(nn.Module):
    """This is a structured projection layer that projects the input to a higher dimension.
    It only allows integer up-projection factors, i.e. the output dimension is a multiple of the input dimension.
    """

    config_class = LinearHeadwiseExpandConfig

    def __init__(self, config: LinearHeadwiseExpandConfig):
        super().__init__()
        self.config = config
        in_features = self.config.in_features
        num_heads = self.config.num_heads
        out_features_per_head = config._out_features // num_heads
        self.weight = nn.Parameter(
            torch.empty(num_heads, out_features_per_head, in_features // num_heads),
            requires_grad=config.trainable_weight,
        )
        if config.bias:
            self.bias = nn.Parameter(
                torch.empty(config._out_features), requires_grad=config.trainable_bias
            )
        else:
            self.bias = None
        self.reset_parameters()

    def reset_parameters(self, **kwargs):
        # small init
        nn.init.normal_(
            self.weight.data, mean=0.0, std=sqrt(2 / 5 / self.weight.shape[-1])
        )
        if self.bias is not None:
            nn.init.zeros_(self.bias.data)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shape = x.shape
        x = x.view(*shape[:-1], self.config.num_heads, -1)
        x = torch.einsum("...hd,hod->...ho", x, self.weight)
        x = x.reshape(*shape[:-1], -1)
        if self.bias is not None:
            x = x + self.bias
        return x

    def extra_repr(self):
        return (
            f"in_features={self.config.in_features}, "
            f"num_heads={self.config.num_heads}, "
            f"expand_factor_up={self.config.expand_factor_up}, "
            f"bias={self.config.bias}, "
            f"trainable_weight={self.config.trainable_weight}, "
            f"trainable_bias={self.config.trainable_bias}, "
        )
```
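A usage sketch with illustrative values (again assuming the `protxlstm` package is importable): 64 input features are split into 8 heads of 8 features each, every head is projected independently, and the overall width is expanded by a factor of 2.

```python
import torch
from protxlstm.xlstm.components.linear_headwise import (
    LinearHeadwiseExpand,
    LinearHeadwiseExpandConfig,
)

config = LinearHeadwiseExpandConfig(in_features=64, num_heads=8, expand_factor_up=2)
layer = LinearHeadwiseExpand(config)

x = torch.randn(2, 10, 64)  # (B, T, in_features)
y = layer(x)

print(y.shape)  # torch.Size([2, 10, 128])
print(layer)    # extra_repr() shows the config fields
```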
protxlstm/xlstm/components/ln.py
ADDED
@@ -0,0 +1,68 @@
```python
# Copyright (c) NXAI GmbH and its affiliates 2024
# Maximilian Beck, Korbinian Pöppel
import torch
import torch.nn.functional as F
from torch import nn


class LayerNorm(nn.Module):
    """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False."""

    def __init__(
        self,
        ndim: int = -1,
        weight: bool = True,
        bias: bool = False,
        eps: float = 1e-5,
        residual_weight: bool = True,
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(ndim)) if weight else None
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
        self.eps = eps
        self.residual_weight = residual_weight
        self.ndim = ndim
        self.reset_parameters()

    @property
    def weight_proxy(self) -> torch.Tensor:
        if self.weight is None:
            return None
        if self.residual_weight:
            return 1.0 + self.weight
        else:
            return self.weight

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.layer_norm(
            input, normalized_shape=(self.ndim,), weight=self.weight_proxy, bias=self.bias, eps=self.eps
        )

    def reset_parameters(self):
        if self.weight_proxy is not None:
            if self.residual_weight:
                nn.init.zeros_(self.weight)
            else:
                nn.init.ones_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)


class MultiHeadLayerNorm(LayerNorm):

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        assert input.dim() == 4, "Input must be 4D tensor (B, NH, S, DH)"
        B, NH, S, DH = input.shape

        gn_in_1 = input.transpose(1, 2)  # (B, S, NH, DH)
        gn_in_2 = gn_in_1.reshape(B * S, NH * DH)  # (B * S, NH * DH)
        out = F.group_norm(
            gn_in_2,
            num_groups=NH,
            weight=self.weight_proxy,
            bias=self.bias,
            eps=self.eps,
        )
        # (B * S), (NH * DH) -> (B, S, NH, DH) -> (B, NH, S, DH)
        out = out.view(B, S, NH, DH).transpose(1, 2)
        return out
```
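A shape sketch for the two normalization layers (illustrative sizes, package import assumed as above): `LayerNorm` normalizes over the last dimension of a `(B, S, D)` tensor, while `MultiHeadLayerNorm` expects `(B, NH, S, DH)` and normalizes each head separately through `group_norm`, so its `ndim` is the flattened `NH * DH`.

```python
import torch
from protxlstm.xlstm.components.ln import LayerNorm, MultiHeadLayerNorm

ln = LayerNorm(ndim=64)
x = torch.randn(2, 10, 64)  # (B, S, D)
print(ln(x).shape)          # torch.Size([2, 10, 64])

mh_ln = MultiHeadLayerNorm(ndim=4 * 16)  # one weight entry per (head, feature) pair
h = torch.randn(2, 4, 10, 16)            # (B, NH, S, DH)
print(mh_ln(h).shape)                    # torch.Size([2, 4, 10, 16])
```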
protxlstm/xlstm/components/rotary_position.py
ADDED
@@ -0,0 +1,35 @@
```python
import torch
from typing import Tuple
import math


def compute_freqs_cis(t: torch.Tensor, head_dim: int, theta: float = 10_000.0):
    freqs = theta ** (-torch.arange(0, head_dim, 2).float() / head_dim)
    freqs = t.unsqueeze(-1) * freqs.to(t.device)  # type: ignore
    freqs_cos = torch.cos(freqs)  # real part
    freqs_sin = torch.sin(freqs)  # imaginary part
    return freqs_cos, freqs_sin


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:

    # reshape xq and xk to match the complex representation
    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

    # apply rotation using real numbers
    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

    # flatten last two dimensions
    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(-2)
    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(-2)

    return xq_out.type_as(xq), xk_out.type_as(xk)
```
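A usage sketch for the rotary helpers (shapes are illustrative; `head_dim` must be even): precompute the cos/sin tables for a sequence of positions, then rotate query and key tensors, which keeps their original shapes.

```python
import torch
from protxlstm.xlstm.components.rotary_position import compute_freqs_cis, apply_rotary_emb

B, NH, S, head_dim = 2, 4, 10, 16
positions = torch.arange(S).float()
freqs_cos, freqs_sin = compute_freqs_cis(positions, head_dim)  # each (S, head_dim // 2)

xq = torch.randn(B, NH, S, head_dim)
xk = torch.randn(B, NH, S, head_dim)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)

print(xq_rot.shape, xk_rot.shape)  # same shapes as the inputs
```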