Tschoui committed on
Commit 48097f5 · 1 Parent(s): 66581ec

Migrate application to Hugging Face

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +4 -0
  2. README.md +93 -14
  3. app.py +151 -0
  4. frontend/constants.py +41 -0
  5. prot_xlstm_env.yml +39 -0
  6. protxlstm/__init__.py +1 -0
  7. protxlstm/applications/__init__.py +0 -0
  8. protxlstm/applications/fitness_prediction.py +214 -0
  9. protxlstm/applications/generation_utils/create_sequence_df.py +85 -0
  10. protxlstm/applications/generation_utils/score_hamming.py +80 -0
  11. protxlstm/applications/generation_utils/score_hmmer.py +102 -0
  12. protxlstm/applications/generation_utils/score_structure.py +55 -0
  13. protxlstm/applications/msa_sampler.py +196 -0
  14. protxlstm/applications/sample_sequences.py +200 -0
  15. protxlstm/applications/score_sequences.py +58 -0
  16. protxlstm/checkpoints/small/config.json +1 -0
  17. protxlstm/checkpoints/small/optimizer.pt +3 -0
  18. protxlstm/checkpoints/small/pytorch_model.bin +3 -0
  19. protxlstm/checkpoints/small/rng_state.pth +3 -0
  20. protxlstm/checkpoints/small/scheduler.pt +3 -0
  21. protxlstm/checkpoints/small/trainer_state.json +0 -0
  22. protxlstm/data.py +60 -0
  23. protxlstm/dataloaders.py +249 -0
  24. protxlstm/fim.py +203 -0
  25. protxlstm/generation.py +384 -0
  26. protxlstm/index.html +16 -0
  27. protxlstm/mamba_utils_generation.py +382 -0
  28. protxlstm/models/__init__.py +0 -0
  29. protxlstm/models/llama.py +342 -0
  30. protxlstm/models/mamba.py +833 -0
  31. protxlstm/models/xlstm.py +180 -0
  32. protxlstm/plot_utils.py +26 -0
  33. protxlstm/train.py +338 -0
  34. protxlstm/trainer.py +123 -0
  35. protxlstm/utils.py +482 -0
  36. protxlstm/xlstm/__init__.py +6 -0
  37. protxlstm/xlstm/blocks/__init__.py +0 -0
  38. protxlstm/xlstm/blocks/mlstm/__init__.py +1 -0
  39. protxlstm/xlstm/blocks/mlstm/backends.py +314 -0
  40. protxlstm/xlstm/blocks/mlstm/block.py +27 -0
  41. protxlstm/xlstm/blocks/mlstm/cell.py +212 -0
  42. protxlstm/xlstm/blocks/mlstm/layer.py +217 -0
  43. protxlstm/xlstm/blocks/xlstm_block.py +111 -0
  44. protxlstm/xlstm/components/__init__.py +0 -0
  45. protxlstm/xlstm/components/conv.py +163 -0
  46. protxlstm/xlstm/components/feedforward.py +88 -0
  47. protxlstm/xlstm/components/init.py +32 -0
  48. protxlstm/xlstm/components/linear_headwise.py +92 -0
  49. protxlstm/xlstm/components/ln.py +68 -0
  50. protxlstm/xlstm/components/rotary_position.py +35 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
1
+ */*.pyc
2
+ *.pyc
3
+ */__pycache__
4
+ __pycache__
README.md CHANGED
@@ -1,14 +1,93 @@
1
- ---
2
- title: Prot Xlstm Variant Fitness
3
- emoji: ⚡
4
- colorFrom: indigo
5
- colorTo: blue
6
- sdk: streamlit
7
- sdk_version: 1.43.2
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- short_description: 'This application enables to inspect mutational effects on a '
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # prot-xLSTM-app
2
+
3
+
4
+
5
+ ## Getting started
6
+
7
+ To make it easy for you to get started with GitLab, here's a list of recommended next steps.
8
+
9
+ Already a pro? Just edit this README.md and make it your own. Want to make it easy? [Use the template at the bottom](#editing-this-readme)!
10
+
11
+ ## Add your files
12
+
13
+ - [ ] [Create](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#create-a-file) or [upload](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#upload-a-file) files
14
+ - [ ] [Add files using the command line](https://docs.gitlab.com/ee/gitlab-basics/add-file.html#add-a-file-using-the-command-line) or push an existing Git repository with the following command:
15
+
16
+ ```
17
+ cd existing_repo
18
+ git remote add origin https://git.bioinf.jku.at/chemoinformatics/prot-xlstm-app.git
19
+ git branch -M main
20
+ git push -uf origin main
21
+ ```
22
+
23
+ ## Integrate with your tools
24
+
25
+ - [ ] [Set up project integrations](https://git.bioinf.jku.at/chemoinformatics/prot-xlstm-app/-/settings/integrations)
26
+
27
+ ## Collaborate with your team
28
+
29
+ - [ ] [Invite team members and collaborators](https://docs.gitlab.com/ee/user/project/members/)
30
+ - [ ] [Create a new merge request](https://docs.gitlab.com/ee/user/project/merge_requests/creating_merge_requests.html)
31
+ - [ ] [Automatically close issues from merge requests](https://docs.gitlab.com/ee/user/project/issues/managing_issues.html#closing-issues-automatically)
32
+ - [ ] [Enable merge request approvals](https://docs.gitlab.com/ee/user/project/merge_requests/approvals/)
33
+ - [ ] [Set auto-merge](https://docs.gitlab.com/ee/user/project/merge_requests/merge_when_pipeline_succeeds.html)
34
+
35
+ ## Test and Deploy
36
+
37
+ Use the built-in continuous integration in GitLab.
38
+
39
+ - [ ] [Get started with GitLab CI/CD](https://docs.gitlab.com/ee/ci/quick_start/)
40
+ - [ ] [Analyze your code for known vulnerabilities with Static Application Security Testing (SAST)](https://docs.gitlab.com/ee/user/application_security/sast/)
41
+ - [ ] [Deploy to Kubernetes, Amazon EC2, or Amazon ECS using Auto Deploy](https://docs.gitlab.com/ee/topics/autodevops/requirements.html)
42
+ - [ ] [Use pull-based deployments for improved Kubernetes management](https://docs.gitlab.com/ee/user/clusters/agent/)
43
+ - [ ] [Set up protected environments](https://docs.gitlab.com/ee/ci/environments/protected_environments.html)
44
+
45
+ ***
46
+
47
+ # Editing this README
48
+
49
+ When you're ready to make this README your own, just edit this file and use the handy template below (or feel free to structure it however you want - this is just a starting point!). Thanks to [makeareadme.com](https://www.makeareadme.com/) for this template.
50
+
51
+ ## Suggestions for a good README
52
+
53
+ Every project is different, so consider which of these sections apply to yours. The sections used in the template are suggestions for most open source projects. Also keep in mind that while a README can be too long and detailed, too long is better than too short. If you think your README is too long, consider utilizing another form of documentation rather than cutting out information.
54
+
55
+ ## Name
56
+ Choose a self-explaining name for your project.
57
+
58
+ ## Description
59
+ Let people know what your project can do specifically. Provide context and add a link to any reference visitors might be unfamiliar with. A list of Features or a Background subsection can also be added here. If there are alternatives to your project, this is a good place to list differentiating factors.
60
+
61
+ ## Badges
62
+ On some READMEs, you may see small images that convey metadata, such as whether or not all the tests are passing for the project. You can use Shields to add some to your README. Many services also have instructions for adding a badge.
63
+
64
+ ## Visuals
65
+ Depending on what you are making, it can be a good idea to include screenshots or even a video (you'll frequently see GIFs rather than actual videos). Tools like ttygif can help, but check out Asciinema for a more sophisticated method.
66
+
67
+ ## Installation
68
+ Within a particular ecosystem, there may be a common way of installing things, such as using Yarn, NuGet, or Homebrew. However, consider the possibility that whoever is reading your README is a novice and would like more guidance. Listing specific steps helps remove ambiguity and gets people using your project as quickly as possible. If it only runs in a specific context like a particular programming language version or operating system or has dependencies that have to be installed manually, also add a Requirements subsection.
69
+
70
+ ## Usage
71
+ Use examples liberally, and show the expected output if you can. It's helpful to have inline the smallest example of usage that you can demonstrate, while providing links to more sophisticated examples if they are too long to reasonably include in the README.
72
+
73
+ ## Support
74
+ Tell people where they can go for help. It can be any combination of an issue tracker, a chat room, an email address, etc.
75
+
76
+ ## Roadmap
77
+ If you have ideas for releases in the future, it is a good idea to list them in the README.
78
+
79
+ ## Contributing
80
+ State if you are open to contributions and what your requirements are for accepting them.
81
+
82
+ For people who want to make changes to your project, it's helpful to have some documentation on how to get started. Perhaps there is a script that they should run or some environment variables that they need to set. Make these steps explicit. These instructions could also be useful to your future self.
83
+
84
+ You can also document commands to lint the code or run tests. These steps help to ensure high code quality and reduce the likelihood that the changes inadvertently break something. Having instructions for running tests is especially helpful if it requires external setup, such as starting a Selenium server for testing in a browser.
85
+
86
+ ## Authors and acknowledgment
87
+ Show your appreciation to those who have contributed to the project.
88
+
89
+ ## License
90
+ For open source projects, say how it is licensed.
91
+
92
+ ## Project status
93
+ If you have run out of energy or time for your project, put a note at the top of the README saying that development has slowed down or stopped completely. Someone may choose to fork your project or volunteer to step in as a maintainer or owner, allowing your project to keep going. You can also make an explicit request for maintainers.
app.py ADDED
@@ -0,0 +1,151 @@
1
+ import os
2
+ os.environ["CUDA_VISIBLE_DEVICES"] = "" #disable cuda
3
+
4
+ import streamlit as st
5
+ import numpy as np
6
+ import torch
7
+ import time
8
+ from Bio import SeqIO
9
+
10
+ from protxlstm.applications.fitness_prediction import single_mutation_landscape_xlstm, create_mutation_df
11
+ from protxlstm.applications.msa_sampler import sample_msa
12
+ from protxlstm.models.xlstm import xLSTMLMHeadModel
13
+ from protxlstm.utils import load_model
14
+ import io
15
+
16
+ from frontend.constants import info_text, citation_text
17
+
18
+ DEFAULT_SEQUENCE = "MTARGLALGLLLLLLCPAQVFSQSCVWYGECGIAYGDKRYNCEYSGPPKPLPKDGYDLVQELCPGFFFGNVSLCCDVRQLQTLKDNLQLPLQFLSRCPSCFYNLLNLFCELTCSPRQSQFLNVTATEDYVDPVTNQTKTNVKELQYYVGQSFANAMYNACRDVEAPSSNDKALGLLCGKDADACNATNWIEYMFNKDNGQAPFTITPVFSDFPVHGMEPMNNATKGCDESVDEVTAPCSCQDCSIVCGPKPQPPPPPAPWTILGLDAMYVIMWITYMAFLLVFFGAFFAVWCYRKRYFVSEYTPIDSNIAFSVNASDKGEASCCDPVSAAFEGCLRRLFTRWGSFCVRNPGCVIFFSLVFITACSSGLVFVRVTTNPVDLWSAPSSQARLEKEYFDQHFGPFFRTEQLIIRAPLTDKHIYQPYPSGADVPFGPPLDIQILHQVLDLQIAIENITASYDNETVTLQDICLAPLSPYNTNCTILSVLNYFQNSHSVLDHKKGDDFFVYADYHTHFLYCVRAPASLNDTSLLHDPCLGTFGGPVFPWLVLGGYDDQNYNNATALVITFPVNNYYNDTEKLQRAQAWEKEFINFVKNYKNPNLTISFTAERSIEDELNRESDSDVFTVVISYAIMFLYISLALGHMKSCRRLLVDSKVSLGIAGILIVLSSVACSLGVFSYIGLPLTLIVIEVIPFLVLAVGVDNIFILVQAYQRDERLQGETLDQQLGRVLGEVAPSMFLSSFSETVAFFLGALSVMPAVHTFSLFAGLAVFIDFLLQITCFVSLLGLDIKRQEKNRLDIFCCVRGAEDGTSVQASESCLFRFFKNSYSPLLLKDWMRPIVIAIFVGVLSFSIAVLNKVDIGLDQSLSMPDDSYMVDYFKSISQYLHAGPPVYFVLEEGHDYTSSKGQNMVCGGMGCNNDSLVQQIFNAAQLDNYTRIGFAPSSWIDDYFDWVKPQSSCCRVDNITDQFCNASVVDPACVRCRPLTPEGKQRPQGGDFMRFLPMFLSDNPNPKCGKGGHAAYSSAVNILLGHGTRVGATYFMTYHTVLQTSADFIDALKKARLIASNVTETMGINGSAYRVFPYSVFYVFYEQYLTIIDDTIFNLGVSLGAIFLVTMVLLGCELWSAVIMCATIAMVLVNMFGVMWLWGISLNAVSLVNLVMSCGISVEFCSHITRAFTVSMKGSRVERAEEALAHMGSSVFSGITLTKFGGIVVLAFAKSQIFQIFYFRMYLAMVLLGATHGLIFLPVLLSYIGPSVNKAKSCATEERYKGTERERLLNF"
19
+
20
+ mutation_positions = []
21
+ msa_file = None
22
+
23
+ if 'fitness_done' not in st.session_state:
24
+ st.session_state.fitness_done = False
25
+ st.session_state.mutations = None
26
+ st.session_state.fitness_duration = None
27
+ st.session_state.target_sequence = ""
28
+ st.session_state.context_sequences = []
29
+ st.session_state.num_context_sequences = 25
30
+
31
+ def run_model():
32
+ try:
33
+ st.session_state.fitness_duration = time.time()
34
+ checkpoint = "protxlstm/checkpoints/small"
35
+ num_context_tokens = 2**15
36
+ df_mutations = create_mutation_df(st.session_state.target_sequence, mutation_positions)
37
+ if msa_file is not None and st.session_state.num_context_sequences != 0:
38
+ def load_sequences_from_msa_file(file_obj):
39
+ text_io = io.TextIOWrapper(file_obj, encoding="utf-8")
40
+ sequences = [str(record.seq) for record in SeqIO.parse(text_io, "fasta")]
41
+ return sequences
42
+ msa_sequences = [msa.upper() for msa in load_sequences_from_msa_file(msa_file)]
43
+ st.session_state.context_sequences = sample_msa(msa_sequences, max_context_sequences=st.session_state.num_context_sequences, context_length=num_context_tokens)
44
+ st.session_state.context_sequences += [st.session_state.target_sequence]
45
+
46
+ config_update_kwargs = {
47
+ "mlstm_backend": "chunkwise_variable",
48
+ "mlstm_chunksize": 1024,
49
+ "mlstm_return_last_state": True}
50
+
51
+ model = load_model(
52
+ checkpoint,
53
+ model_class=xLSTMLMHeadModel,
54
+ device='cpu',
55
+ dtype=torch.bfloat16,
56
+ **config_update_kwargs,
57
+ )
58
+ model = model.eval()
59
+ st.session_state.mutations, _ = single_mutation_landscape_xlstm(model, df_mutations, st.session_state.context_sequences, chunk_chunk_size=2**15)
60
+ print("fitness_done")
61
+ st.session_state.fitness_done = True
62
+ st.session_state.fitness_duration = time.time() - st.session_state.fitness_duration
63
+ except Exception as e:
64
+ print(e)
65
+
66
+ # PAGE STYLE (mainly for custom aa selection)
67
+ st.set_page_config(layout="wide")
68
+ st.markdown(
69
+ """
70
+ <style>
71
+ .stButtonGroup button {
72
+ padding: 0px 1px 0px 1px !important;
73
+ border: 0 solid transparent !important;
74
+ min-height: 0px !important;
75
+ line-height: 120% !important;
76
+ height: auto !important;
77
+ }
78
+ .stSidebar {
79
+ width: 600px !important;
80
+ }
81
+ </style>
82
+ """,
83
+ unsafe_allow_html=True
84
+ )
85
+
86
+
87
+ with st.sidebar:
88
+ st.title("Prot-xLSTM Variant Fitness")
89
+
90
+ # LOAD SEQUENCE
91
+ st.session_state.target_sequence = st.text_area(
92
+ "Target protein sequence",
93
+ placeholder=DEFAULT_SEQUENCE,
94
+ value=st.session_state.target_sequence
95
+ )
96
+ if st.button("Load sequence"):
97
+ if st.session_state.target_sequence == "":
98
+ st.session_state.target_sequence = DEFAULT_SEQUENCE
99
+
100
+ # MANAGE CONTEXT SEQUENCES
101
+ context_type = st.selectbox(
102
+ "Choose how to enter context",
103
+ ("Enter manually", "Use MSA file"),
104
+ index=None,
105
+ placeholder="Choose context",
106
+ )
107
+ if context_type == 'Enter manually':
108
+ context_sequence_str = st.text_area(
109
+ "Enter context protein sequences (separated by commas)",
110
+ placeholder=DEFAULT_SEQUENCE,
111
+ )
112
+ st.session_state.context_sequences = context_sequence_str.split(",") + [st.session_state.target_sequence]
113
+ elif context_type == 'Use MSA file':
114
+ msa_file = st.file_uploader("Choose MSA file")
115
+ st.session_state.num_context_sequences = st.number_input("How many of these sequences should be used?", min_value=0, step=1, value=25)
116
+ else:
117
+ st.session_state.context_sequences = [st.session_state.target_sequence]
118
+
119
+ if st.session_state.target_sequence != "":
120
+ with st.container():
121
+
122
+ # MUTATION POSITION SELECTION
123
+ aas = list(st.session_state.target_sequence)
124
+ mutation_indices = np.arange(1, len(aas)+1)
125
+ mutation_positions = st.segmented_control(
126
+ "Choose mutation positions (click to select)", mutation_indices, selection_mode="multi", format_func=lambda i: aas[i-1],
127
+ )
128
+ st.button("Check Fitness", on_click=run_model)
129
+
130
+ # DISPLAY RESULTS
131
+ if st.session_state.fitness_done:
132
+ st.metric(label="Running time", value=f"{st.session_state.fitness_duration:.2f} sec.")
133
+ selected_pos = st.selectbox(
134
+ "Visualized mutation position",
135
+ st.session_state.mutations['position'].unique()
136
+ )
137
+ selected_data = st.session_state.mutations.where(st.session_state.mutations['position'] == selected_pos)
138
+ st.bar_chart(selected_data, x='mutation', y='effect', horizontal=True)
139
+ st.dataframe(st.session_state.mutations, use_container_width=True)
140
+
141
+ # TUTORIAL
142
+ with st.expander("Info & Tutorial", expanded=True):
143
+ st.subheader("Tutorial")
144
+ st.markdown("**1.** Choose a target protein sequence (leave empty to use a sample sequence) and press 'Load sequence'.")
145
+ st.markdown("**2.** Enter or upload your context sequences (leave empty to use no context).")
146
+ st.markdown("**3.** Choose which amino acids to mutate (click on the AA's to select them) and press 'Check Fitness'")
147
+ st.subheader("General Information")
148
+ st.markdown(info_text, unsafe_allow_html=True)
149
+ st.markdown("")
150
+ st.subheader("Cite us / BibTex")
151
+ st.code(citation_text, language=None)
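
The `run_model` callback above chains `create_mutation_df`, `sample_msa`, and `single_mutation_landscape_xlstm`. A minimal sketch of the same pipeline outside Streamlit, assuming the bundled `protxlstm/checkpoints/small` checkpoint and an arbitrary placeholder target sequence:

```python
# Non-Streamlit sketch of the pipeline used in run_model() above.
# The target sequence and mutation positions are arbitrary examples.
import torch

from protxlstm.applications.fitness_prediction import (
    create_mutation_df,
    single_mutation_landscape_xlstm,
)
from protxlstm.models.xlstm import xLSTMLMHeadModel
from protxlstm.utils import load_model

target = "MTARGLALGLLLLLLCPAQVFSQSCVWYGECGIAYGDKRYNCEYSGPPKPLPKDGYDLVQ"  # truncated example
context = [target]                               # no homologs: the context is only the target
mutations = create_mutation_df(target, [5, 10])  # 1-based positions to scan

model = load_model(
    "protxlstm/checkpoints/small",
    model_class=xLSTMLMHeadModel,
    device="cpu",
    dtype=torch.bfloat16,
    mlstm_backend="chunkwise_variable",
    mlstm_chunksize=1024,
    mlstm_return_last_state=True,
).eval()

mutations, _ = single_mutation_landscape_xlstm(model, mutations, context, chunk_chunk_size=2**15)
print(mutations.sort_values("effect", ascending=False).head())
```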
frontend/constants.py ADDED
@@ -0,0 +1,41 @@
1
+ info_text = ("""
2
+ <div style="text-align: justify">
3
+ This application enables you to inspect mutational effects on a
4
+ predefined protein sequence.<br>
5
+ <br>
6
+ </div>
7
+
8
+ <div style="text-align: justify">
9
+ It is built on the Prot-xLSTM backbone model, an xLSTM model specifically
10
+ trained on protein sequences. Prot-xLSTM was trained using the
11
+ Fill-In-the-Middle (FIM) objective, which allows it to perform sequence
12
+ inpainting. Additionally, the model can be provided with a potentially
13
+ large set of homologous sequences to enhance its predictions.<br>
14
+ <br>
15
+ </div>
16
+
17
+ <div style="text-align: justify">
18
+ For further information, please refer to: <a href="https://openreview.net/forum?id=IjbXZdugdj" target="_blank">https://openreview.net/forum?id=IjbXZdugdj</a>. <br>
19
+ <br>
20
+
21
+ This Hugging Face application is based on the following GitHub repository:
22
+ <a href="https://github.com/ml-jku/Prot-xLSTM?tab=readme-ov-file" target="_blank">https://github.com/ml-jku/Prot-xLSTM?tab=readme-ov-file</a>. <br>
23
+ The Streamlit application was developed by Elias Bürger.
24
+ </div>
25
+
26
+ <div style="text-align: justify">
27
+ Please cite us as follows: <br>
28
+ </div>
29
+ """)
30
+ citation_text = """
31
+ @misc{
32
+ schmidinger2024bioxlstmgenerativemodelingrepresentation,
33
+ title={Bio-xLSTM: Generative modeling, representation and in-context learning of biological and chemical sequences},
34
+ author={Niklas Schmidinger and Lisa Schneckenreiter and Philipp Seidl and Johannes Schimunek and Pieter-Jan Hoedt and Johannes Brandstetter and Andreas Mayr and Sohvi Luukkonen and Sepp Hochreiter and Günter Klambauer},
35
+ year={2024},
36
+ eprint={2411.04165},
37
+ archivePrefix={arXiv},
38
+ primaryClass={q-bio.BM},
39
+ url={https://arxiv.org/abs/2411.04165},
40
+ }
41
+ """
prot_xlstm_env.yml ADDED
@@ -0,0 +1,39 @@
1
+ name: prot_xlstm_app
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ - conda-forge
6
+ - defaults
7
+ dependencies:
8
+ - cuda=12.1
9
+ - cuda-nvcc=12.1
10
+ - gxx_linux-64=11.2.0
11
+ - python=3.11
12
+ - pip
13
+ - pytorch=2.2.0
14
+ - pytorch-cuda=12.1
15
+ - cmake
16
+ - ninja
17
+ - pip:
18
+ - accelerate>=0.26.0
19
+ - biopython #==1.83
20
+ - bottleneck #==1.4.2
21
+ - dacite #==1.8.1
22
+ - ipykernel #==6.29.3
23
+ - mamba_ssm==1.2.0
24
+ - matplotlib #==3.8.4
25
+ - numpy<2.0 #==1.26.4
26
+ - omegaconf #==2.3.0
27
+ - pandas #==2.2.2
28
+ - pyhmmer #==0.10.15
29
+ - rich #==13.7.1
30
+ - scipy #==1.13.0
31
+ - seaborn #==0.13.2
32
+ - torchmetrics #==1.2.1
33
+ - tqdm #==4.66.4
34
+ - transformers==4.44.2
35
+ - tueplots #==0.0.17
36
+ - wandb #==0.17.0
37
+ - streamlit #==1.43.2
38
+
39
+
protxlstm/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # __version__ = "0.0.1"
protxlstm/applications/__init__.py ADDED
File without changes
protxlstm/applications/fitness_prediction.py ADDED
@@ -0,0 +1,214 @@
1
+ import os
2
+ import numpy as np
3
+ import pandas as pd
4
+
5
+ import torch
6
+ from tqdm.auto import tqdm
7
+
8
+ from protxlstm.applications.msa_sampler import MSASampler
9
+ from protxlstm.generation import generate_sequence
10
+ from protxlstm.utils import AA_TO_ID, tokenizer, ID_TO_AA
11
+
12
+
13
+ def precompute_context_state(model, sequences, chunk_chunk_size=2**15):
14
+ """
15
+ Precompute the output states for a fixed context that remains the same across generations.
16
+ Returns the hidden states to continue generation later.
17
+ """
18
+ device = next(model.parameters()).device
19
+
20
+ input_ids, pos_ids = prepare_context(sequences)
21
+ state = None
22
+
23
+ for chunk in range(input_ids.shape[1]//chunk_chunk_size+1):
24
+
25
+ start_idx = chunk*chunk_chunk_size
26
+ end_idx = min((chunk+1)*chunk_chunk_size, input_ids.shape[1])
27
+
28
+ if start_idx == end_idx:
29
+ pass
30
+
31
+ else:
32
+ input_ids_chunk = input_ids[:, start_idx:end_idx].to(device)
33
+ pos_ids_chunk = pos_ids[:, start_idx:end_idx].to(device)
34
+
35
+ with torch.no_grad():
36
+ outputs = model(input_ids=input_ids_chunk,
37
+ position_ids=pos_ids_chunk,
38
+ state=state,
39
+ output_hidden_states=True,
40
+ return_dict=True)
41
+ state = outputs.state
42
+
43
+ # Return the hidden states for reuse
44
+ return state
45
+
46
+ def prepare_context(sequences):
47
+ tokenized_sequences = tokenizer(sequences, concatenate=False)
48
+ pos_ids = torch.cat([torch.arange(0, len(seq), dtype=torch.int64) for seq in tokenized_sequences], 0)[None, :]
49
+ input_ids = torch.cat(tokenized_sequences, 0)[None, :].to(torch.int64)
50
+ return input_ids, pos_ids
51
+
52
+ def prepare_single_mutation_target(target, mut_pos):
53
+
54
+ pos_ids = torch.arange(target.shape[1], dtype=torch.int64)[None,:] # default position ids
55
+ t = torch.ones((target.shape[0], 1), dtype=torch.int64)
56
+ new_target = torch.cat([
57
+ target[:,:mut_pos], # WT sequence until mutated position
58
+ AA_TO_ID["<mask-1>"] * t, # Mask token at the mutated position
59
+ target[:,mut_pos+1:], # WT sequence after mutated position
60
+ AA_TO_ID["<eos>"] * t, # End of sequence token
61
+ AA_TO_ID["<mask-1>"] * t, # Mask token
62
+ ], dim=1)
63
+ new_pos_ids = torch.cat([
64
+ pos_ids,
65
+ 0 * t, # end of sequence
66
+ mut_pos * t, # mutation position
67
+ ], dim=1)
68
+
69
+ is_fim_dict = { AA_TO_ID["<mask-1>"] : pos_ids[:,mut_pos].squeeze().item()}
70
+
71
+ return new_target, new_pos_ids, is_fim_dict
72
+
73
+ def single_mutation_landscape_xlstm(model, single_mutations, context_sequences, chunk_chunk_size=2**15):
74
+
75
+ device = next(model.parameters()).device
76
+
77
+ # Tokenize WT target sequence
78
+ wt_tokens = tokenizer([context_sequences[-1]], concatenate=True)
79
+
80
+ # Precompute hidden state of context
81
+ context_state = precompute_context_state(model, context_sequences, chunk_chunk_size=chunk_chunk_size)
82
+
83
+ mutation_positions = sorted(single_mutations.position.unique())
84
+ all_logits = np.zeros((len(mutation_positions), 20))
85
+
86
+ # Iterate over all mutated positions
87
+ for i, pos in tqdm(enumerate(mutation_positions), total=len(mutation_positions), desc="Generating mutational landscape"): # This loop can be parallelized
88
+
89
+ # Prepare target
90
+ wt_aa_id = wt_tokens[0, pos+1].int().item() # wild type AA index
91
+ target_tokens, target_pos_ids, _ = prepare_single_mutation_target(wt_tokens, pos+1)
92
+
93
+ with torch.no_grad():
94
+ outputs = model(input_ids=target_tokens.to(device),
95
+ position_ids=target_pos_ids.to(device),
96
+ state=context_state,
97
+ )
98
+
99
+ # Extract logits and compute mutational effect
100
+ logits = outputs.logits.clone().detach() # Raw logits
101
+ logits_mut = logits[0, -1, 4:24].log_softmax(-1) # Log-softmax for mutation prediction: indices 4:24 correspond to the 20 natural amino acids
102
+ mut_effects = logits_mut - logits_mut[wt_aa_id - 4] # Subtract log probability of ground truth
103
+ all_logits[i,:] = logits_mut.cpu()
104
+ single_mutations.loc[single_mutations.position == pos, 'effect'] = single_mutations.loc[single_mutations.position == pos, 'mutation_idx'].apply(lambda x : mut_effects[x-4].item())
105
+
106
+ return single_mutations, all_logits
107
+
108
+ def single_mutation_landscape_mamba(model, single_mutations, context_sequences):
109
+
110
+ # Prepare context sequences
111
+ context_tokens, context_pos_ids = prepare_context(context_sequences)
112
+
113
+ # Tokenize WT target sequence
114
+ wt_tokens = tokenizer([context_sequences[-1]], concatenate=True)
115
+
116
+ mutation_positions = sorted(single_mutations.position.unique())
117
+ all_logits = np.zeros((len(mutation_positions), 20))
118
+
119
+ # Iterate over all mutated positions
120
+ for i, pos in tqdm(enumerate(mutation_positions), total=len(mutation_positions), desc="Generating mutational landscape"): # This loop can be parallelized
121
+
122
+ # Prepare target
123
+ wt_aa_id = wt_tokens[0, pos+1].int().item() # wild type AA index
124
+ target_tokens, target_pos_ids, is_fim_dict = prepare_single_mutation_target(wt_tokens, pos+1)
125
+
126
+ # Merge context and target
127
+ device = next(model.parameters()).device
128
+ context_tokens = torch.cat([context_tokens, target_tokens], dim=1).to(device)
129
+ context_pos_ids = torch.cat([context_pos_ids, target_pos_ids], dim=1).to(device)
130
+
131
+ # Generate fim-token prediction
132
+ output = generate_sequence(
133
+ model,
134
+ context_tokens,
135
+ position_ids=context_pos_ids,
136
+ is_fim=is_fim_dict,
137
+ max_length=1,
138
+ temperature=1.0,
139
+ top_k=0,
140
+ top_p=0.0,
141
+ return_dict_in_generate=True,
142
+ output_scores=True,
143
+ eos_token_id=AA_TO_ID["<cls>"],
144
+ device=device
145
+ )
146
+
147
+ # Extract logits and compute mutational effect
148
+ logits = torch.tensor(output["scores"]) # Raw logits
149
+ logits_mut = logits[0, 0, 4:24].log_softmax(-1) # Log-softmax for mutation prediction: indices 4:24 correspond to the 20 natural amino acids
150
+ mut_effects = logits_mut - logits_mut[wt_aa_id - 4] # Subtract log probability of ground truth
151
+ all_logits[i,:] = logits_mut.cpu()
152
+
153
+ single_mutations.loc[single_mutations.position == pos, 'effect'] = single_mutations.loc[single_mutations.position == pos, 'mutation_idx'].apply(lambda x : mut_effects[x-4].item())
154
+
155
+ return single_mutations, all_logits
156
+
157
+ def single_mutation_landscape_retrieval(single_mutations, msa_sequences, msa_weights_path):
158
+
159
+ # One-hot encode MSA sequences
160
+ msa_tokens = np.array([[AA_TO_ID[aa.upper()] for aa in seq] for seq in msa_sequences])
161
+ one_hot_tokens = np.zeros((len(msa_tokens), len(msa_tokens[0]), 40))
162
+ one_hot_tokens[np.arange(len(msa_tokens))[:, None], np.arange(len(msa_tokens[0])), msa_tokens] = 1
163
+
164
+ #Load/compute weights
165
+ if os.path.exists(msa_weights_path):
166
+ weights = np.load(msa_weights_path)
167
+ else:
168
+ sampler = MSASampler(0.98, 0.7)
169
+ weights = sampler.get_weights(msa_tokens)[1]
170
+ np.save(msa_weights_path, weights)
171
+ assert one_hot_tokens.shape[0] == weights.shape[0]
172
+
173
+ # Apply sequence weights, normalize amino acid probabilities per position, and convert to a PyTorch tensor.
174
+ one_hot_tokens = one_hot_tokens * weights[:, None, None]
175
+ one_hot_tokens = one_hot_tokens.sum(0)
176
+ one_hot_tokens = one_hot_tokens[:, 4:24] + 1 / len(msa_sequences)
177
+ one_hot_tokens_sum = one_hot_tokens.sum(-1)
178
+ one_hot_tokens = one_hot_tokens / one_hot_tokens_sum[:, None]
179
+ one_hot_tokens = torch.tensor(one_hot_tokens).float()
180
+
181
+ # Compute mutational effects
182
+ wild_type = msa_tokens[0]
183
+ logits = one_hot_tokens.log()
184
+ logits = logits - logits[torch.arange(len(logits)), wild_type - 4][:, None]
185
+
186
+ single_mutations['retrieval_effect'] = single_mutations.apply(
187
+ lambda row: logits[row['position'], row['mutation_idx'] - 4].item(), axis=1)
188
+
189
+ return single_mutations
190
+
191
+
192
+ def create_mutation_df(sequence, mutation_positions):
193
+ """
194
+ Generate a DataFrame containing all possible mutations at specified positions in a sequence.
195
+
196
+ Args:
197
+ sequence (str): The original sequence to mutate.
198
+ mutation_positions (list of int): List of positions to mutate (1-based index).
199
+
200
+ Returns:
201
+ pd.DataFrame:
202
+ - 'mutation': formatted mutation string (e.g., 'A10G' for Ala at position 10 to Gly).
203
+ - 'position': 0-based position in the sequence.
204
+ - 'mutation_idx': numeric index for the mutation.
205
+ """
206
+
207
+ AAs = {k: v for k, v in ID_TO_AA.items() if 4 <= k <= 23}
208
+ mutation_data = []
209
+ for position in mutation_positions:
210
+ wt = sequence[position - 1]
211
+ for idx, aa in AAs.items():
212
+ mutation = f"{wt}{position}{aa}"
213
+ mutation_data.append({'mutation': mutation, 'position': position - 1, 'mutation_idx': idx})
214
+ return pd.DataFrame(mutation_data)
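
The single-mutation scoring above relies on the fill-in-the-middle layout built by `prepare_single_mutation_target`: the wild-type residue is replaced with `<mask-1>`, and a second `<mask-1>` appended after `<eos>` marks where the model predicts the masked residue. A small sketch with an arbitrary example sequence:

```python
# Sketch of the FIM target built by prepare_single_mutation_target() above.
# The sequence is an arbitrary example; token ids come from protxlstm.utils.
from protxlstm.applications.fitness_prediction import prepare_single_mutation_target
from protxlstm.utils import tokenizer

wt_tokens = tokenizer(["MKTAYIAK"], concatenate=True)  # special start token + residue tokens
target, pos_ids, fim = prepare_single_mutation_target(wt_tokens, 3)

# target = wild-type tokens with index 3 masked, then <eos>, then a trailing <mask-1>;
# the model's logits at that trailing mask give the amino-acid distribution for the site.
print(target.shape, pos_ids.shape, fim)
```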
protxlstm/applications/generation_utils/create_sequence_df.py ADDED
@@ -0,0 +1,85 @@
1
+ import numpy as np
2
+ import pickle
3
+ import pandas as pd
4
+
5
+ from protxlstm.dataloaders import ProteinMemmapDataset
6
+ from protxlstm.utils import decode_sequence, reorder_masked_sequence
7
+
8
+
9
+ def create_sequence_df(model_name, family_idx, parameters_list=None, num_sequences = 100, data_dir="./data/"):
10
+
11
+ #load dataset
12
+ dataset = ProteinMemmapDataset(
13
+ msa_memmap_path=f"{data_dir}open_protein_set_memmap.dat",
14
+ msa_memmap_meta_path=f"{data_dir}open_protein_set_memmap_indices.csv",
15
+ subset_path=f"{data_dir}/cluster_testing_set.txt",
16
+ sample=False,
17
+ max_msa_len=-1,
18
+ reverse=False,
19
+ seed=0,
20
+ troubleshoot=False,
21
+ fim_strategy="multiple_span",
22
+ always_mask=False,
23
+ max_position_embeddings=2048,
24
+ max_seq_position_embeddings=512,
25
+ add_position_ids="1d",
26
+ mask_fraction=0.2,
27
+ max_patches=5,
28
+ )
29
+
30
+ family_id = list(dataset.dataset_meta["msa_id"])[family_idx]
31
+
32
+ if model_name == "natural":
33
+
34
+ data = dataset[family_idx]
35
+ sequence_df = pd.DataFrame(columns=["family", "family_id", "sequence", "sequence_length"])
36
+ tokens = data["input_ids"][None,:]
37
+ all_context = decode_sequence(tokens[0].cpu().numpy())
38
+ list_sequences_msa = [reorder_masked_sequence(elem+"<cls>") for elem in all_context.split("<cls>")[1:-1]]
39
+
40
+ rd_idxs = np.random.choice(len(list_sequences_msa), num_sequences, replace=False)
41
+ natural_sequences = [seq for i, seq in enumerate(list_sequences_msa) if i in rd_idxs]
42
+
43
+ df_dict = {"family": [family_idx]*len(natural_sequences),
44
+ "family_id": [family_id]*len(natural_sequences),
45
+ "sequence": natural_sequences,
46
+ "sequence_length": [len(seq) for seq in natural_sequences]}
47
+
48
+ sequence_df = pd.concat([sequence_df, pd.DataFrame(df_dict)], ignore_index = True)
49
+
50
+ else:
51
+
52
+ sequence_df = pd.DataFrame(columns=["family", "family_id", "n_seqs_ctx", "temperature", "top_k", "top_p", "original_sequence", "sequence", "sequence_length", "perplexity"])
53
+
54
+ if parameters_list is None:
55
+ parameters_list = [(10,1.,10,1.), (10,1.,15,1.), (10,1.,10,0.95), (10,0.9,10,0.95), (10,0.8,10,0.9),
56
+ (100,1.,10,1.), (100,1.,15,1.), (100,1.,10,0.95), (100,0.9,10,0.95), (100,0.8,10,0.9),
57
+ (500,1.,10,1.), (500,1.,15,1.), (500,1.,10,0.95), (500,0.9,10,0.95), (500,0.8,10,0.9),
58
+ (1000,1.,10,1.), (1000,1.,15,1.), (1000,1.,10,0.95), (1000,0.9,10,0.95), (1000,0.8,10,0.9),
59
+ (-1,1.,10,1.), (-1,1.,15,1.), (-1,1.,10,0.95), (-1,0.9,10,0.95), (-1,0.8,10,0.9)]
60
+
61
+ for param in parameters_list:
62
+ n_seqs_ctx, temperature, top_k, top_p = param
63
+
64
+ with open(f"evaluation/generation/generated_sequences/{model_name}/{family_idx}_{param}_{num_sequences}", "rb") as f:
65
+ gen_seqs = pickle.load(f)
66
+
67
+ original_sequences = list(gen_seqs[family_idx][param].keys())
68
+ reordered_sequences = [reorder_masked_sequence(seq) for seq in original_sequences]
69
+ perplexities = [gen_seqs[family_idx][param][seq]["perplexity"] for seq in original_sequences]
70
+ df_dict = {"family": [family_idx]*len(original_sequences),
71
+ "family_id": [family_id]*len(original_sequences),
72
+ "n_seqs_ctx": [n_seqs_ctx]*len(original_sequences),
73
+ "temperature": [temperature]*len(original_sequences),
74
+ "top_k": [top_k]*len(original_sequences),
75
+ "top_p": [top_p]*len(original_sequences),
76
+ "original_sequence": original_sequences,
77
+ "sequence": reordered_sequences,
78
+ "sequence_length": [len(seq) for seq in reordered_sequences],
79
+ "perplexity": perplexities
80
+ }
81
+
82
+ sequence_df = pd.concat([sequence_df, pd.DataFrame(df_dict)], ignore_index = True)
83
+
84
+ return sequence_df
85
+
protxlstm/applications/generation_utils/score_hamming.py ADDED
@@ -0,0 +1,80 @@
1
+ import numpy as np
2
+ from tqdm import tqdm
3
+ import pandas as pd
4
+ from Bio import Align
5
+
6
+ from protxlstm.dataloaders import ProteinMemmapDataset
7
+ from protxlstm.utils import decode_sequence, reorder_masked_sequence
8
+
9
+
10
+ aligner = Align.PairwiseAligner()
11
+ aligner.mode = 'global'
12
+ aligner.match_score = 1
13
+ aligner.mismatch_score = -1
14
+ aligner.open_gap_score = -1
15
+ aligner.extend_gap_score = -1
16
+
17
+ def align_sequences(ref_seq, query_seq, print_alignments=False):
18
+ def hamming_str(s1,s2):
19
+ assert len(s1) == len(s2)
20
+ return sum(np.array(list(s1)) != np.array(list(s2)))/len(s1)
21
+ alignments = aligner.align(ref_seq, query_seq)
22
+ if print_alignments:
23
+ print("Score = %.1f:" % alignments[0].score)
24
+ print(alignments[0])
25
+ return hamming_str(alignments[0][0], alignments[0][1]), alignments[0][0], alignments[0][1]
26
+
27
+
28
+ def score_hamming(sequence_df, family_idx, data_dir = f"./data/"):
29
+
30
+ assert len(set(list(sequence_df["family"]))) == 1 and sequence_df["family"].iloc[0] == family_idx
31
+
32
+ #load dataset
33
+ dataset = ProteinMemmapDataset(
34
+ msa_memmap_path=f"{data_dir}open_protein_set_memmap.dat",
35
+ msa_memmap_meta_path=f"{data_dir}open_protein_set_memmap_indices.csv",
36
+ subset_path=f"{data_dir}/cluster_testing_set.txt",
37
+ sample=False,
38
+ max_msa_len=-1,
39
+ reverse=False,
40
+ seed=0,
41
+ troubleshoot=False,
42
+ fim_strategy="multiple_span",
43
+ always_mask=False,
44
+ max_position_embeddings=2048,
45
+ max_seq_position_embeddings=512,
46
+ add_position_ids="1d",
47
+ mask_fraction=0.2,
48
+ max_patches=5,
49
+ )
50
+
51
+ # Select a sample of the dataset to be the input
52
+ data = dataset[family_idx]
53
+ tokens = data["input_ids"][None,:]
54
+ all_context = decode_sequence(tokens[0].cpu().numpy())
55
+ list_sequences_msa = [reorder_masked_sequence(elem+"<cls>") for elem in all_context.split("<cls>")[1:-1]]
56
+
57
+ # sequence_df["hamming"] = pd.Series(dtype=object)
58
+ sequence_df["min_hamming"] = pd.Series()
59
+ sequence_df["median_hamming"] = pd.Series()
60
+ sequence_df["mean_hamming"] = pd.Series()
61
+ sequence_df["std_hamming"] = pd.Series()
62
+
63
+ for seq in tqdm(list(sequence_df["sequence"])):
64
+
65
+ all_hamming = []
66
+ for ctx_seq in list_sequences_msa:
67
+ if ctx_seq == seq:
68
+ continue
69
+ else:
70
+ hamming, _, _ = align_sequences(ctx_seq, seq , print_alignments=False)
71
+ all_hamming.append(hamming)
72
+
73
+ # sequence_df.loc[sequence_df["sequence"] == seq, "hamming"] = [all_hamming]
74
+ sequence_df.loc[sequence_df["sequence"] == seq, "min_hamming"] = np.min(all_hamming)
75
+ sequence_df.loc[sequence_df["sequence"] == seq, "median_hamming"] = np.median(all_hamming)
76
+ sequence_df.loc[sequence_df["sequence"] == seq, "mean_hamming"] = np.mean(all_hamming)
77
+ sequence_df.loc[sequence_df["sequence"] == seq, "std_hamming"] = np.std(all_hamming)
78
+
79
+ return sequence_df
80
+
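
`align_sequences` wraps Biopython's global pairwise aligner and reports a normalized Hamming distance over the aligned strings. A minimal usage sketch with two arbitrary sequences:

```python
# Usage sketch for align_sequences() above; the sequences are arbitrary examples.
from protxlstm.applications.generation_utils.score_hamming import align_sequences

ref = "MKTAYIAKQRQISFVKSHFSRQ"
query = "MKTAYLAKQRQISFVKAHFSRQ"

dist, aligned_ref, aligned_query = align_sequences(ref, query)
print(f"Normalized Hamming distance after global alignment: {dist:.3f}")
```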
protxlstm/applications/generation_utils/score_hmmer.py ADDED
@@ -0,0 +1,102 @@
1
+ import string
2
+ from Bio import SeqIO
3
+ import pyhmmer
4
+ from tqdm import tqdm
5
+
6
+ alphabet = pyhmmer.easel.Alphabet.amino()
7
+
8
+ # This is an efficient way to delete lowercase characters and insertion characters from a string
9
+ deletekeys = dict.fromkeys(string.ascii_lowercase)
10
+ deletekeys["."] = None
11
+ deletekeys["*"] = None
12
+ translation = str.maketrans(deletekeys)
13
+
14
+ def remove_insertions(sequence: str) -> str:
15
+ """ Removes any insertions into the sequence. Needed to load aligned sequences in an MSA. """
16
+ return sequence.translate(translation)
17
+
18
+ def read_msa(filename: str):
19
+ """ Reads the sequences from an MSA file, automatically removes insertions."""
20
+ return [(record.description, remove_insertions(str(record.seq))) for record in SeqIO.parse(filename, "fasta")]
21
+
22
+ def read_msa_unaligned(filename: str):
23
+ """ Reads the sequences from an MSA file, removes only . - and * characters."""
24
+ return [(record.description, str(record.seq).replace(".","").replace("-","").replace("*","").upper()) for record in SeqIO.parse(filename, "fasta")]
25
+
26
+ def check_msa(msa):
27
+ """ Checks if there are any repeated sequences in the MSA"""
28
+ seqs = set()
29
+ for el in msa:
30
+ seqs.add(el[1])
31
+ assert len(seqs) == len(msa), "There are repeated sequences in the MSA"
32
+
33
+ def make_hmm_from_a3m_msa(msa_filepath, hmm_filename=None):
34
+ # Load MSA from a3m
35
+ msa_tup = read_msa(msa_filepath)
36
+ # check_msa(msa_tup)
37
+ # Create digitized MSA block
38
+ all_seqs = [pyhmmer.easel.TextSequence(name=str(i).encode("utf-8"), sequence=seq) for i, (idz, seq) in enumerate(msa_tup)]
39
+ msa = pyhmmer.easel.TextMSA(name=b"msa", sequences=all_seqs)
40
+ msa = msa.digitize(alphabet)
41
+ # Fit HMM
42
+ builder = pyhmmer.plan7.Builder(alphabet)
43
+ background = pyhmmer.plan7.Background(alphabet)
44
+ hmm, _, _ = builder.build_msa(msa, background)
45
+ if hmm_filename is not None:
46
+ with open(f"{hmm_filename}.hmm", "wb") as output_file:
47
+ hmm.write(output_file)
48
+ return hmm
49
+
50
+ def align_and_score_sequences_in_a3m_with_hmm(hmm, sequences_path=None, sequences_list=None):
51
+ if sequences_list is not None:
52
+ msa = sequences_list
53
+ all_seqs = [pyhmmer.easel.TextSequence(name=str(i).encode("utf-8"), sequence=seq) for i, seq in enumerate(sequences_list)]
54
+ elif sequences_path is not None:
55
+ # Load sequences from a3m
56
+ msa = read_msa_unaligned(sequences_path)
57
+ all_seqs = [pyhmmer.easel.TextSequence(name=str(i).encode("utf-8"), sequence=seq) for i, (idz, seq) in enumerate(msa)]
58
+ else:
59
+ raise NotImplementedError("Missing sequences to align/score")
60
+ # Create digitized Sequence block
61
+ seq_block = pyhmmer.easel.TextSequenceBlock(all_seqs)
62
+ seq_block = seq_block.digitize(alphabet)
63
+ # Get all hits from the hmm
64
+ background = pyhmmer.plan7.Background(alphabet)
65
+ pipeline = pyhmmer.plan7.Pipeline(alphabet, background=background, bias_filter=False, F1=1.0, F2=1.0, F3=1.0)
66
+ hits = pipeline.search_hmm(hmm, seq_block)
67
+ if len(hits) != len(msa):
68
+ print(f"Number of hits: {len(hits)} is different from the number of sequences in the MSA: {len(msa)}")
69
+ # Extract hits
70
+ all_hits = {}
71
+ for hit in hits:
72
+ idz, score, evalue = hit.name, hit.score, hit.evalue
73
+ i = int(idz.decode("utf-8"))
74
+ seq = msa[i][1] if sequences_path is not None else sequences_list[i]
75
+ all_hits[seq] = {"score": score, "evalue": evalue}
76
+ return all_hits
77
+
78
+
79
+ def score_hmmer(sequence_df, family_idx, data_dir = f"./data/"):
80
+
81
+ assert len(set(list(sequence_df["family"]))) == 1 and sequence_df["family"].iloc[0] == family_idx
82
+
83
+ family_id = sequence_df["family_id"].iloc[0]
84
+ msa_filepath = f"{data_dir}/a3m_files/{family_id}/a3m/uniclust30.a3m"
85
+ try:
86
+ hmm = make_hmm_from_a3m_msa(msa_filepath)
87
+ except:
88
+ raise Exception(f"Missing MSA of family {family_id}")
89
+
90
+ # align sequences
91
+ sequences = list(sequence_df["sequence"])
92
+ scores = align_and_score_sequences_in_a3m_with_hmm(hmm, sequences_list=sequences)
93
+
94
+ # save the scores associated to each sequence in the main df in the columns "score" and "evalue"
95
+ for seq in tqdm(sequences):
96
+ sequence_df.loc[sequence_df["sequence"] == seq, "score_gen"] = scores[seq]["score"] if seq in scores.keys() else 0
97
+ sequence_df.loc[sequence_df["sequence"] == seq, "evalue_gen"] = scores[seq]["evalue"] if seq in scores.keys() else 1
98
+
99
+ return sequence_df
100
+
101
+
102
+
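
`score_hmmer` first fits a profile HMM to the family MSA and then scores each generated sequence against it with pyHMMER. A hedged sketch of the two helpers, with an illustrative a3m path and arbitrary sequences:

```python
# Sketch of the HMM-based scoring used in score_hmmer() above.
# The a3m path is illustrative; the sequences are arbitrary examples.
from protxlstm.applications.generation_utils.score_hmmer import (
    align_and_score_sequences_in_a3m_with_hmm,
    make_hmm_from_a3m_msa,
)

hmm = make_hmm_from_a3m_msa("data/a3m_files/<family_id>/a3m/uniclust30.a3m")
hits = align_and_score_sequences_in_a3m_with_hmm(
    hmm,
    sequences_list=["MKTAYIAKQRQISFVKSHFSRQ", "MKTAYLAKQRQISFVKAHFSRQ"],
)
for seq, res in hits.items():
    print(seq, res["score"], res["evalue"])
```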
protxlstm/applications/generation_utils/score_structure.py ADDED
@@ -0,0 +1,55 @@
1
+ from Bio.PDB import PDBParser
2
+ import torch
3
+ from tqdm import tqdm
4
+ from transformers import EsmForProteinFolding
5
+
6
+ from protxlstm.utils import MASK_TO_ID
7
+
8
+
9
+ pdb_parser = PDBParser()
10
+
11
+
12
+ def compute_structure(seq, model):
13
+ def keep_sequence(seq, l):
14
+ if len(seq) > l:
15
+ return False
16
+ for mm in list(MASK_TO_ID.keys())+["<eos>", "<pad>", "<unk>", "<mask>", "<cls>", "<null_1>", "." , "-"]:
17
+ if mm in seq:
18
+ return False
19
+ return True
20
+ keep = keep_sequence(seq, l=750)
21
+ if keep:
22
+ with torch.no_grad():
23
+ output = model.infer([seq])
24
+ # pdb = model.output_to_pdb(output)
25
+ ptm = output["ptm"].item()
26
+ pae = output["predicted_aligned_error"].cpu().numpy()
27
+ mean_plddt = ((output["plddt"] * output["atom37_atom_exists"]).sum(dim=(1, 2)) / output["atom37_atom_exists"].sum(dim=(1, 2))).item()
28
+ pos_plddt = ((output["plddt"] * output["atom37_atom_exists"]).sum(dim=(2,)) / output["atom37_atom_exists"].sum(dim=(2,))).cpu().numpy()
29
+ else:
30
+ print(f"Sequence is invalid.")
31
+ ptm, pae, mean_plddt, pos_plddt = 0, 0 ,0 , 0
32
+ return ptm, pae, mean_plddt, pos_plddt
33
+
34
+
35
+ def score_structure(sequence_df, family_idx):
36
+
37
+ assert len(set(list(sequence_df["family"]))) == 1 and sequence_df["family"].iloc[0] == family_idx
38
+
39
+ device="cuda:0"
40
+
41
+ # Import the folding model
42
+ model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True)
43
+
44
+ model = model.cuda(device)
45
+ model.esm = model.esm.half()
46
+ torch.backends.cuda.matmul.allow_tf32 = True
47
+
48
+ sequences = list(sequence_df["sequence"])
49
+ for seq in tqdm(sequences):
50
+
51
+ ptm, pae, mean_plddt, pos_plddt = compute_structure(seq, model)
52
+ sequence_df.loc[sequence_df["sequence"] == seq, "ptm"] = ptm
53
+ sequence_df.loc[sequence_df["sequence"] == seq, "mean_plddt"] = mean_plddt
54
+
55
+ return sequence_df
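
`compute_structure` folds one sequence with ESMFold and extracts pTM and pLDDT confidence scores. A sketch of a standalone call, assuming a CUDA device and the `facebook/esmfold_v1` weights used in `score_structure` above:

```python
# Sketch of a single-sequence call to compute_structure() above.
# Requires a GPU; the sequence is an arbitrary example.
from transformers import EsmForProteinFolding

from protxlstm.applications.generation_utils.score_structure import compute_structure

model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True)
model = model.cuda()
model.esm = model.esm.half()

ptm, pae, mean_plddt, pos_plddt = compute_structure("MKTAYIAKQRQISFVKSHFSRQ", model)
print(f"pTM={ptm:.2f}, mean pLDDT={mean_plddt:.2f}")
```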
protxlstm/applications/msa_sampler.py ADDED
@@ -0,0 +1,196 @@
1
+ # Original code from ProtMamba under Apache License 2.0.
2
+ #
3
+ # Modifications made by Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
4
+ # - Modify handling of weights in `MSASampler`
5
+
6
+ import math
7
+ import os
8
+ from typing import Optional, Callable
9
+
10
+ import numpy as np
11
+ import torch
12
+
13
+ from protxlstm.utils import AA_TO_ID
14
+
15
+
16
+ def compute_hamming_csim_torch(
17
+ seqs: torch.Tensor,
18
+ ungapped_msa: torch.Tensor,
19
+ gap_token: int,
20
+ gap_token_mask: int,
21
+ ) -> torch.Tensor:
22
+ return (seqs.unsqueeze(1) == ungapped_msa).sum(dim=2)
23
+
24
+ def _compute_homology_weights(
25
+ ungapped_msa: np.ndarray,
26
+ gap_token: int,
27
+ gap_token_mask: int,
28
+ theta: float,
29
+ hamming_csim_func: Callable,
30
+ max_memory: int = 20,
31
+ can_use_torch: bool = True,
32
+ ) -> np.ndarray:
33
+ use_torch = can_use_torch
34
+ if use_torch:
35
+ hamming_csim_func = compute_hamming_csim_torch
36
+ batch_size = math.floor(
37
+ 2
38
+ * 1024
39
+ * 1024
40
+ * 1024
41
+ / (ungapped_msa.shape[0] * ungapped_msa.shape[1])
42
+ * max_memory
43
+ / 40
44
+ )
45
+
46
+ batch_size = 1 if batch_size == 0 else batch_size
47
+
48
+ neighbors = []
49
+ if not use_torch:
50
+ masked_ungapped_msa = ungapped_msa.copy()
51
+ else:
52
+ ungapped_msa = torch.from_numpy(ungapped_msa).byte()
53
+ masked_ungapped_msa = ungapped_msa.clone()
54
+ masked_ungapped_msa[masked_ungapped_msa == gap_token] = gap_token_mask
55
+ for b_start in range(0, len(ungapped_msa), batch_size):
56
+ b_end = b_start + batch_size
57
+ seqs = ungapped_msa[b_start:b_end]
58
+
59
+ sim = hamming_csim_func(
60
+ seqs=seqs,
61
+ ungapped_msa=masked_ungapped_msa,
62
+ gap_token=gap_token,
63
+ gap_token_mask=gap_token_mask,
64
+ )
65
+ if not use_torch:
66
+ sim = sim / (seqs != gap_token).sum(axis=1, keepdims=True)
67
+ d = 1 - sim
68
+ d = d.clamp(0, 1)
69
+ this_neighbors = (d <= theta).sum(axis=1)
70
+ else:
71
+ sim = sim / (seqs != gap_token).sum(dim=1, keepdim=True)
72
+ d = 1 - sim
73
+ # fillna
74
+ d[torch.isnan(d)] = 0
75
+ d = d.clamp(0, 1)
76
+ this_neighbors = (d <= theta).sum(dim=1).cpu()
77
+ neighbors.append(this_neighbors)
78
+ return np.concatenate(neighbors)
79
+
80
+ def compute_homology_weights(
81
+ ungapped_msa: np.ndarray,
82
+ theta: float = 0.2,
83
+ gap_token: int = AA_TO_ID["-"],
84
+ gap_token_mask: int = 255,
85
+ hamming_csim_func: Callable = compute_hamming_csim_torch,
86
+ ) -> tuple[int, np.ndarray]:
87
+ """
88
+ Calculate the effective number of sequences and sampling probability for the NEIGHBORS and NEIGHBORS_NO_LIMIT sampling methods using numpy.
89
+
90
+ Parameters:
91
+
92
+ ungapped_msa (np.ndarray): The MSA (from .fa).
93
+ theta (float, optional): A parameter used to determine the similarity between sequences. Default is 0.2.
94
+ gap_token (int, optional): The token representing gaps in the (Uniprot21 encoded) MSA. Default is 20.
95
+ gap_token_mask (int): token for masking gaps. should be a token not representing any other value.
96
+
97
+ Returns:
98
+
99
+ tuple[int, np.ndarray]: A tuple containing the effective number of sequences and the sampling probability for each sequence in the MSA.
100
+ """
101
+ neighbors = _compute_homology_weights(
102
+ ungapped_msa=ungapped_msa,
103
+ gap_token=gap_token,
104
+ gap_token_mask=gap_token_mask,
105
+ theta=theta,
106
+ hamming_csim_func=hamming_csim_func,
107
+ )
108
+ n_eff = np.sum(1 / neighbors)
109
+
110
+ p = 1 / neighbors
111
+ p /= np.sum(p)
112
+ return n_eff, p
113
+
114
+ class MSASampler:
115
+
116
+ def __init__(self, max_similarity, max_dissimilarity, force_include_first=True):
117
+ self.max_similarity = max_similarity
118
+ self.max_dissimilarity = max_dissimilarity
119
+ self.force_include_first = force_include_first
120
+ self.theta = 0.2
121
+
122
+ def _get_sim_filtered_idxs(self, msa: np.ndarray) -> np.ndarray:
123
+ nonnormalized_sim = (msa == msa[[0]]).sum(axis=1)
124
+ normfactor = msa.shape[1]
125
+ norm_sim = nonnormalized_sim / normfactor
126
+
127
+ assert (norm_sim.min() >= 0) and (norm_sim.max() <= 1)
128
+ dsim = 1 - norm_sim
129
+
130
+ max_sim_filter = norm_sim <= self.max_similarity
131
+ max_dissim_filter = dsim <= self.max_dissimilarity
132
+ return np.where(max_sim_filter & max_dissim_filter)[0]
133
+
134
+ def get_weights(
135
+ self, msa: np.ndarray,
136
+ ) -> tuple[Optional[float], Optional[np.ndarray]]:
137
+ return compute_homology_weights(
138
+ ungapped_msa=msa,
139
+ theta=self.theta,
140
+ gap_token_mask=255,
141
+
142
+ )
143
+
144
+ def get_sample_idxs(
145
+ self,
146
+ msa: np.ndarray,
147
+ size: int = 1,
148
+ random = False,
149
+ msa_weights_path = None,
150
+ seed = 0,
151
+ ) -> np.ndarray:
152
+
153
+ np.random.seed(seed)
154
+
155
+ if random:
156
+ return np.random.choice(len(msa), replace=False, size=size) if len(msa) >= size else np.arange(len(msa))
157
+
158
+ msa = np.array([[AA_TO_ID[aa] for aa in seq.upper()][:len(msa[0])] for seq in msa], dtype=np.uint8)
159
+
160
+ if msa_weights_path and os.path.exists(msa_weights_path):
161
+ weights = np.load(msa_weights_path)
162
+ elif msa_weights_path:
163
+ os.makedirs(os.path.dirname(msa_weights_path), exist_ok=True)
164
+ _, weights = self.get_weights(
165
+ msa=msa,
166
+ )
167
+ np.save(msa_weights_path, weights)
168
+ else:
169
+ _, weights = self.get_weights(
170
+ msa=msa,
171
+ )
172
+
173
+
174
+ original_msa_sample_idxs = np.arange(len(msa))
175
+ sample_idxs = self._get_sim_filtered_idxs(msa)
176
+ original_msa_sample_idxs = original_msa_sample_idxs[sample_idxs]
177
+
178
+ if self.force_include_first:
179
+ original_msa_sample_idxs = np.concatenate(
180
+ [[0], original_msa_sample_idxs[original_msa_sample_idxs != 0]]
181
+ )
182
+ return np.random.choice(len(msa), replace=False, size=size, p=weights / weights.sum()) if len(msa) >= size else original_msa_sample_idxs
183
+
184
+ def sample_msa(msa_sequences, msa_weights_path=None, context_length=200_000, max_context_sequences=200, seed=0, sort=True):
185
+ """Sample MSA sequences for the context"""
186
+ n_sequences = min( context_length // len(msa_sequences[0]), len(msa_sequences) if max_context_sequences == 0 else max_context_sequences ) - 1
187
+ sampler = MSASampler(0.98, 0.7, force_include_first=False)
188
+ sample_idx = sampler.get_sample_idxs(
189
+ msa_sequences, size=n_sequences, msa_weights_path=msa_weights_path, seed=seed
190
+ )
191
+
192
+ # Sort sequences from least similar to most similar and add wild type target sequence
193
+ if sort:
194
+ context_sequences = [msa_sequences[i] for i in sample_idx][::-1]
195
+
196
+ return context_sequences
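
`sample_msa` down-samples an MSA so it fits the model's context budget, weighting sequences by homology before sampling. A toy usage sketch; a real MSA would come from an a3m file as in `app.py`:

```python
# Toy usage sketch for sample_msa() above; the MSA is an arbitrary example.
from protxlstm.applications.msa_sampler import sample_msa

msa = [
    "MKTAYIAKQRQISFVKSHFSRQ",   # first entry is treated as the query sequence
    "MKTAYLAKQRQISFVKAHFSRQ",
    "MKTAFIAKQRQISYVKSHFSRQ",
]
context = sample_msa(msa, max_context_sequences=3, context_length=2**15, seed=0)
print(len(context), "context sequences sampled")
```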
protxlstm/applications/sample_sequences.py ADDED
@@ -0,0 +1,200 @@
1
+ import torch
2
+ from tqdm import tqdm
3
+ import pickle
4
+ import os
5
+ import argparse
6
+ import json
7
+
8
+ from protxlstm.dataloaders import ProteinMemmapDataset
9
+ from protxlstm.generation import generate_sequence
10
+ from protxlstm.utils import (
11
+ AA_TO_ID,
12
+ load_model,
13
+ )
14
+ from protxlstm.models.xlstm import xLSTMLMHeadModel
15
+ from protxlstm.models.mamba import MambaLMHeadModelwithPosids
16
+
17
+
18
+ def sample_sequences(dataset,
19
+ model,
20
+ family_idx,
21
+ params,
22
+ n_samples_per_family,
23
+ max_length=1000,
24
+ chunk_chunk_size=2**15,
25
+ save_path=None,
26
+ device="cuda:0"):
27
+ """
28
+ Function to sample sequences from the model. Given a dataset, a list of families (their indexes in the dataset)
29
+ and a set of generating parameters, it generates `n_samples_per_family` sequences for each family and each parameter set.
30
+ The function returns a dictionary with the following structure:
31
+ gen_seqs = {family_idx: {parameters: {sequence: perplexity}}}
32
+ The parameters are in a list of tuples with the following structure:
33
+ parameters_list = [(nr_seqs_ctx, temperature, top_k, top_p)]
34
+ """
35
+ gen_seqs = {}
36
+ gen_seqs[family_idx] = {}
37
+ gen_seqs[family_idx][params] = {}
38
+ print(f"Sampling sequences for family {family_idx} and parameters {params}.")
39
+
40
+ n_seqs_ctx , temperature, top_k, top_p = params
41
+ for _ in tqdm(range(n_samples_per_family)):
42
+ # Sample the dataset to get the input
43
+ data = dataset[family_idx]
44
+ tokens = data["input_ids"][None,:].to(device)
45
+ pos_ids = data["position_ids"][None,:].to(device)
46
+
47
+ start_seqs = torch.argwhere(tokens[0]==0)[:,0].cpu().numpy()
48
+
49
+ n_seqs_ctx = len(start_seqs) if len(start_seqs) < n_seqs_ctx else n_seqs_ctx
50
+ L = start_seqs[n_seqs_ctx]+1
51
+ context_tokens = tokens[:,:L]
52
+ context_pos_ids = pos_ids[:,:L]
53
+ is_fim={}
54
+
55
+ # Generate the new sequence
56
+ output = generate_sequence(model,
57
+ context_tokens,
58
+ position_ids=context_pos_ids,
59
+ is_fim=is_fim,
60
+ max_length=(L+max_length),
61
+ temperature=temperature,
62
+ top_k=top_k,
63
+ top_p=top_p,
64
+ return_dict_in_generate=True,
65
+ output_scores=True,
66
+ eos_token_id=torch.tensor([AA_TO_ID["<cls>"]]).to(device),
67
+ chunk_chunk_size=chunk_chunk_size,
68
+ device=device)
69
+
70
+ # Get the perplexity of the generated sequence
71
+ output_seq = output["generated"]
72
+ loss = torch.nn.functional.cross_entropy(torch.from_numpy(output["scores"]).permute(0, 2, 1),
73
+ torch.from_numpy(output["generated_tokens"][0][None,:]))
74
+
75
+ # save only sequences with length < max_length
76
+ if len(output_seq[0]) < max_length:
77
+
78
+ gen_seqs[family_idx][params][output_seq[0]] = {"perplexity": torch.exp(loss).item()}
79
+
80
+ if save_path is not None:
81
+ if not os.path.exists("evaluation/generation/generated_sequences"):
82
+ os.mkdir("evaluation/generation/generated_sequences")
83
+ if not os.path.exists(save_path):
84
+ os.mkdir(save_path)
85
+ with open(f'{save_path}/{family_idx}_{params}_{n_samples_per_family}', "wb") as f:
86
+ pickle.dump(gen_seqs, f)
87
+ print(f"Sequences saved for family {family_idx} and parameters {params}")
88
+
89
+ return gen_seqs
90
+
91
+ def generate_sequences(model_name,
92
+ checkpoint,
93
+ family_idxs=[],
94
+ parameters_list=[],
95
+ n_samples_per_family = 100,
96
+ chunk_size=1024,
97
+ chunk_chunk_size=2**15,
98
+ data_dir="data/",
99
+ device="cuda:0"
100
+ ):
101
+
102
+ # Load the test dataset
103
+ fim_strategy = "multiple_span"
104
+ mask_fraction = 0.2
105
+
106
+ dataset = ProteinMemmapDataset(
107
+ msa_memmap_path=f"{data_dir}open_protein_set_memmap.dat",
108
+ msa_memmap_meta_path=f"{data_dir}open_protein_set_memmap_indices.csv",
109
+ subset_path=f"{data_dir}cluster_testing_set.txt",
110
+ sample=False,
111
+ max_msa_len=-1,
112
+ reverse=False,
113
+ seed=0,
114
+ troubleshoot=False,
115
+ fim_strategy=fim_strategy,
116
+ always_mask=False,
117
+ max_position_embeddings=2048,
118
+ max_seq_position_embeddings=512,
119
+ add_position_ids="1d",
120
+ mask_fraction=mask_fraction
121
+ )
122
+
123
+ if model_name == "xlstm":
124
+ model_class = xLSTMLMHeadModel
125
+ elif model_name == "mamba":
126
+ model_class = MambaLMHeadModelwithPosids
127
+
128
+ save_path = f"evaluation/generation/generated_sequences/{checkpoint.split('/')[-1]}"
129
+
130
+ if model_name == "xlstm":
131
+ config_update_kwargs = {
132
+ "mlstm_backend": "chunkwise_variable",
133
+ "mlstm_chunksize": chunk_size,
134
+ "mlstm_return_last_state": True
135
+ }
136
+ else:
137
+ config_update_kwargs = {}
138
+
139
+
140
+ #load the model
141
+ model = load_model(checkpoint,
142
+ model_class=model_class,
143
+ device=device,
144
+ dtype=torch.bfloat16,
145
+ **config_update_kwargs,
146
+ )
147
+ model = model.eval()
148
+ print("Model loaded.")
149
+
150
+ for family_idx in family_idxs:
151
+ for params in parameters_list:
152
+ params = tuple(params)
153
+ if not os.path.exists(f'{save_path}/{family_idx}_{params}_{n_samples_per_family}'):
154
+ gen_seqs = sample_sequences(
155
+ dataset=dataset,
156
+ model=model,
157
+ family_idx=family_idx,
158
+ params=params,
159
+ n_samples_per_family=n_samples_per_family,
160
+ chunk_chunk_size=chunk_chunk_size,
161
+ save_path=save_path,
162
+ device=device)
163
+
164
+ print(f"Sampled {len(gen_seqs[family_idx][params])} valid sequences.")
165
+ else:
166
+ print(f"Sequences for family {family_idx} and parameters {params} already exist.")
167
+
168
+
169
+ if __name__ == "__main__":
170
+
171
+ parser = argparse.ArgumentParser(
172
+ description="Generate sequences."
173
+ )
174
+ parser.add_argument("--model_name", type=str, help="Either 'xlstm' or 'mamba'.")
175
+ parser.add_argument("--checkpoint", type=str, help="Path to model checkpoint.")
176
+ parser.add_argument("--family_idxs", type=str, help="List of family indices.")
177
+ parser.add_argument("--parameters_list", type=str, help="List of sampling parameters.")
178
+ parser.add_argument("--n_samples_per_family", type=int, default=100, help="Number of sequences to sample per family and parameter set.")
179
+ parser.add_argument("--chunk_size", type=int, default=1024, help="Chunk size for xLSTM context encoding.")
180
+ parser.add_argument("--chunk_chunk_size", type=int, default=2*15, help="Length of context sequence part processed at once.")
181
+ parser.add_argument("--data_dir", type=str, default="data/", help="Path to dataset.")
182
+ parser.add_argument("--device", type=str, default="cuda:0", help="Device.")
183
+
184
+ args = parser.parse_args()
185
+
186
+ family_idxs = json.loads(args.family_idxs)
187
+ parameters_list = json.loads(args.parameters_list)
188
+
189
+ # Run sequence generation
190
+ generate_sequences(
191
+ model_name=args.model_name,
192
+ checkpoint=args.checkpoint,
193
+ family_idxs=family_idxs,
194
+ parameters_list=parameters_list,
195
+ n_samples_per_family=args.n_samples_per_family,
196
+ chunk_size=args.chunk_size,
197
+ chunk_chunk_size=args.chunk_chunk_size,
198
+ data_dir=args.data_dir,
199
+ device=args.device,
200
+ )
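For reference, a minimal sketch of invoking this entry point from Python rather than the CLI. The checkpoint path points at the small checkpoint shipped in this repository; the values inside each parameter tuple are hypothetical, since their layout is defined by sample_sequences() above.

from protxlstm.applications.sample_sequences import generate_sequences

# Hypothetical sampling-parameter tuple; the exact layout is fixed by sample_sequences() above.
generate_sequences(
    model_name="xlstm",
    checkpoint="protxlstm/checkpoints/small",
    family_idxs=[0],
    parameters_list=[[1.0, 10, 0.95]],
    n_samples_per_family=10,
    data_dir="data/",
    device="cuda:0",
)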
protxlstm/applications/score_sequences.py ADDED
@@ -0,0 +1,58 @@
1
+ import argparse
2
+ import os
3
+ import pickle
4
+
5
+ from generation_utils.create_sequence_df import create_sequence_df
6
+ from generation_utils.score_hamming import score_hamming
7
+ from generation_utils.score_hmmer import score_hmmer
8
+ from generation_utils.score_structure import score_structure
9
+
10
+
11
+ def score_sequences(model_name,
12
+ family_idx,
13
+ num_sequences = 100,
14
+ data_dir = "data/"):
15
+
16
+ if os.path.isfile(f"evaluation/generation/evaluations/{model_name}/sequence_df_{family_idx}"):
17
+ with open(f"evaluation/generation/evaluations/{model_name}/sequence_df_{family_idx}", "rb") as f:
18
+ sequence_df = pickle.load(f)
19
+ else:
20
+ sequence_df = create_sequence_df(model_name, family_idx, data_dir = data_dir, num_sequences = num_sequences)
21
+ if not os.path.exists("evaluation/generation/evaluations/"):
22
+ os.mkdir("evaluation/generation/evaluations/")
23
+ if not os.path.exists(f"evaluation/generation/evaluations/{model_name}/"):
24
+ os.mkdir(f"evaluation/generation/evaluations/{model_name}/")
25
+ with open(f"evaluation/generation/evaluations/{model_name}/sequence_df_{family_idx}", "wb") as f:
26
+ pickle.dump(sequence_df, f)
27
+
28
+ if not "min_hamming" in sequence_df.columns:
29
+ sequence_df = score_hamming(sequence_df, family_idx, data_dir)
30
+ with open(f"evaluation/generation/evaluations/{model_name}/sequence_df_{family_idx}", "wb") as f:
31
+ pickle.dump(sequence_df, f)
32
+
33
+ if not "score_gen" in sequence_df.columns:
34
+ sequence_df = score_hmmer(sequence_df, family_idx, data_dir)
35
+ with open(f"evaluation/generation/evaluations/{model_name}/sequence_df_{family_idx}", "wb") as f:
36
+ pickle.dump(sequence_df, f)
37
+
38
+ if not "ptm" in sequence_df.columns:
39
+ sequence_df = score_structure(sequence_df, family_idx)
40
+ with open(f"evaluation/generation/evaluations/{model_name}/sequence_df_{family_idx}", "wb") as f:
41
+ pickle.dump(sequence_df, f)
42
+
43
+ return sequence_df
44
+
45
+
46
+ if __name__ == "__main__":
47
+
48
+ parser = argparse.ArgumentParser(
49
+ description="Generate sequences."
50
+ )
51
+ parser.add_argument("--model_name", type=str, help="Either 'xlstm' or 'mamba'.")
52
+ parser.add_argument("--family_idx", type=int, help="Family index.")
53
+ parser.add_argument("--num_sequences", type=int, default=100, help="Number of sequences.")
54
+ parser.add_argument("--data_dir", type=str, default="./data/", help="Path to dataset.")
55
+
56
+ args = parser.parse_args()
57
+
58
+ sequence_df = score_sequences(args.model_name, args.family_idx, args.num_sequences, args.data_dir)
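A minimal sketch of calling the scorer directly (assumes the generated-sequence pickles from sample_sequences.py and the MSA data referenced above already exist on disk):

from protxlstm.applications.score_sequences import score_sequences

sequence_df = score_sequences("xlstm", family_idx=0, num_sequences=100, data_dir="data/")
# Columns added by the scoring steps above: min_hamming, score_gen, ptm
print(sequence_df[["min_hamming", "score_gen", "ptm"]].head())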
protxlstm/checkpoints/small/config.json ADDED
@@ -0,0 +1 @@
1
+ {"mlstm_block": {"mlstm": {"proj_factor": 2.0, "round_proj_up_dim_up": true, "round_proj_up_to_multiple_of": 64, "_proj_up_dim": 1024, "conv1d_kernel_size": 4, "qkv_proj_blocksize": 4, "num_heads": 4, "embedding_dim": 512, "bias": false, "dropout": 0.0, "context_length": 2048, "backend": "chunkwise", "chunk_size": 1024, "return_last_state": false, "_num_blocks": 16, "_inner_embedding_dim": 1024}}, "slstm_block": {"slstm": {"hidden_size": 512, "num_heads": 4, "num_states": 4, "backend": "cuda", "function": "slstm", "bias_init": "powerlaw_blockdependent", "recurrent_weight_init": "zeros", "_block_idx": 0, "_num_blocks": 16, "num_gates": 4, "gradient_recurrent_cut": false, "gradient_recurrent_clipval": null, "forward_clipval": null, "batch_size": 8, "input_shape": "BSGNH", "internal_input_shape": "SBNGH", "output_shape": "BNSH", "constants": {}, "dtype": "bfloat16", "dtype_b": "float32", "dtype_r": "bfloat16", "dtype_w": "bfloat16", "dtype_g": "bfloat16", "dtype_s": "bfloat16", "dtype_a": "float32", "enable_automatic_mixed_precision": true, "initial_val": 0.0, "embedding_dim": 512, "conv1d_kernel_size": 4, "dropout": 0.0}, "feedforward": {"proj_factor": 1.3, "round_proj_up_dim_up": true, "round_proj_up_to_multiple_of": 64, "_proj_up_dim": 0, "act_fn": "gelu", "embedding_dim": -1, "dropout": 0.0, "bias": false, "ff_type": "ffn_gated", "_num_blocks": 1}, "_num_blocks": 16, "_block_idx": 0}, "context_length": 2048, "num_blocks": 16, "embedding_dim": 512, "add_post_blocks_norm": true, "bias": false, "dropout": 0.0, "checkpoint_blocks": true, "slstm_at": [], "_block_map": "0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0", "vocab_size": 38, "tie_weights": false, "weight_decay_on_embedding": false, "add_embedding_dropout": false, "position_embeddings": "rot_1d", "max_position_embeddings": 2048, "max_seq_position_embeddings": 512, "rope_base_frequency": 500000}
protxlstm/checkpoints/small/optimizer.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6bcae4c2a893afed859ca9b5926d24e8f7d0b22eca198e4aa950b80909be8e50
3
+ size 207533690
protxlstm/checkpoints/small/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9ba59f1fd544a5d9f6c4adb40730cf90b8f69a772df838246f724586cb1d602a
3
+ size 103773526
protxlstm/checkpoints/small/rng_state.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5480500274058efdf9caa959d35a42a948ff3dc8536e082b9bc22f2ecd423108
3
+ size 14244
protxlstm/checkpoints/small/scheduler.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:25a26dbf75285e97210697d5608b6d76ef35aa0d2879be319ef2785f881153b9
3
+ size 1000
protxlstm/checkpoints/small/trainer_state.json ADDED
The diff for this file is too large to render. See raw diff
 
protxlstm/data.py ADDED
@@ -0,0 +1,60 @@
1
+ import csv
2
+ import os
3
+
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+
7
+ from protxlstm.utils import load_sequences_from_msa_file, tokenizer
8
+
9
+ def process_msa(msa_item):
10
+ msa_name, msa_path = msa_item
11
+ # Load an a3m file with all the context sequences
12
+ msa = load_sequences_from_msa_file(msa_path)
13
+ # Tokenize the sequences and concatenate them into a single array
14
+ tokens = tokenizer(msa, concatenate=True)
15
+ tokens = tokens.numpy()[0]
16
+ return msa_name, tokens
17
+
18
+ def main(data_dir, output_dir):
19
+ msa_paths = {k: os.path.join(data_dir, k, 'a3m/uniclust30.a3m') for k in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, k))}
20
+ msa_items = list(msa_paths.items())
21
+
22
+ dataset_dictionary = {}
23
+ total_length = 0
24
+
25
+ # First pass: calculate total length of all concatenated arrays
26
+ for item in tqdm(msa_items):
27
+ try:
28
+ k, v = process_msa(item)
29
+ dataset_dictionary[k] = v
30
+ total_length += len(v)
31
+ except:
32
+ print(f"Error processing {item}")
33
+
34
+ # Initialize the memmap array with the calculated total length
35
+ memmap_path = os.path.join(output_dir, 'open_protein_set_memmap.dat')
36
+ concatenated_array = np.memmap(memmap_path, dtype='int8', mode='w+', shape=(total_length,))
37
+
38
+ with open(f'{output_dir}/open_protein_set_memmap_indices.csv', 'w', newline='') as csvfile:
39
+ csvwriter = csv.writer(csvfile)
40
+
41
+ csvwriter.writerow(['msa_id', 'Start', 'End'])
42
+
43
+ start_index = 0
44
+ for key, array in dataset_dictionary.items():
45
+ end_index = start_index + len(array) - 1
46
+ concatenated_array[start_index:end_index + 1] = array # Write to memmap
47
+ csvwriter.writerow([key, start_index, end_index])
48
+ start_index = end_index + 1
49
+
50
+ # Ensure the data is written to disk
51
+ concatenated_array.flush()
52
+
53
+
54
+ if __name__ == "__main__":
55
+ data_dir = 'data/a3m_files'
56
+ output_dir = 'data/'
57
+ main(data_dir, output_dir)
58
+
59
+
60
+
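A short sketch of reading back the memmap and index CSV written by main() above (note that the stored End index is inclusive):

import numpy as np
import pandas as pd

meta = pd.read_csv("data/open_protein_set_memmap_indices.csv")
tokens = np.memmap("data/open_protein_set_memmap.dat", dtype="int8", mode="r")
row = meta.iloc[0]
cluster_tokens = tokens[row.Start : row.End + 1]  # End is stored inclusive above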
protxlstm/dataloaders.py ADDED
@@ -0,0 +1,249 @@
1
+ # Original code from ProtMamba under Apache License 2.0.
2
+ #
3
+ # Modifications made by Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
4
+ # - Uniclust30_Dataset renamed to ProteinMemmapDataset
5
+ # - Dataset input file format changed for more efficient dataloading
6
+ # - Option to use only a subset
7
+ # - DataCollatorForUniclust30Dataset renamed to ProteinDataCollator
8
+ # - Add sequence padding
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ import torch
13
+ from torch.utils.data import DataLoader, Dataset
14
+ from typing import Dict, Optional, Sequence
15
+
16
+ from protxlstm.fim import MultipleSpanFIM, NoFIM, SingleSpanFIM
17
+ from protxlstm.utils import AA_TO_ID
18
+
19
+
20
+ # Make dataset
21
+ class ProteinMemmapDataset(Dataset):
22
+ """
23
+ ProteinMemmapDataset is a PyTorch Dataset class for handling memory-mapped datasets of protein multiple sequence alignments (MSAs).
24
+
25
+ This class imports MSA data stored in memmap format and associated metadata CSVs. It supports flexible
26
+ data sampling strategies and inpainting methods for sequence manipulation and training purposes.
27
+
28
+ Args:
29
+ msa_memmap_path (str): Path to the memory-mapped file containing the MSA clusters.
30
+ msa_memmap_meta_path (str): Path to the CSV file with metadata linking MSA Cluster IDs and indices in the memmap array.
31
+ subset_path (str, optional): Path to a CSV file specifying a subset of cluster IDs to use.
32
+ sample (bool, optional): If True, randomly samples sequences from each cluster; otherwise, loads all sequences and shuffles them.
33
+ max_msa_len (int, optional): Maximum length of the MSA sequences to include. Defaults to -1 (no limit).
34
+ reverse (bool, optional): If True, reverses sequences with a probability of 0.5 and moves the last token to the front.
35
+ seed (int, optional): Random seed for reproducibility. Defaults to 42.
36
+ troubleshoot (bool, optional): If True, prints debugging information. Defaults to False.
37
+ fim_strategy (str, optional): Strategy for inpainting ("no-scramble", "one_span", or "multiple_span").
38
+ max_patches (int, optional): Number of patches for inpainting. Used when fim_strategy is "multiple_span".
39
+ mask_fraction (float, optional): Fraction of the patches to mask. Used when fim_strategy is "multiple_span".
40
+ always_mask (bool, optional): If True, ensures masking is applied in the inpainting process.
41
+ max_position_embeddings (int, optional): Maximum position embeddings. Defaults to 2048.
42
+ max_seq_position_embeddings (int, optional): Maximum sequence position embeddings for 2D positional IDs. Defaults to 512.
43
+ add_position_ids (str, optional): Type of position IDs to add ("none", "1d", or "2d"). Defaults to "1d".
44
+ """
45
+
46
+ _FIM = {"no-scramble": NoFIM, "one_span": SingleSpanFIM, "multiple_span": MultipleSpanFIM}
47
+ _POSIDS = {"none", "1d", "2d"}
48
+
49
+ def __init__(self,
50
+ msa_memmap_path=None,
51
+ msa_memmap_meta_path=None,
52
+ subset_path=None,
53
+ sample=False,
54
+ max_msa_len=-1,
55
+ reverse=False,
56
+ seed=42,
57
+ troubleshoot=False,
58
+ fim_strategy="no-scramble",
59
+ max_patches=5,
60
+ mask_fraction=0.2,
61
+ always_mask=False,
62
+ max_position_embeddings=2048,
63
+ max_seq_position_embeddings=512,
64
+ add_position_ids="1d", ):
65
+
66
+ np.random.seed(seed)
67
+
68
+ if msa_memmap_path:
69
+ self.dataset = np.memmap(msa_memmap_path, dtype=np.int8, mode='r')
70
+ self.dataset_meta = pd.read_csv(msa_memmap_meta_path)
71
+ if subset_path:
72
+ subset_ids = pd.read_csv(subset_path, header=None, names=['ID'])['ID'].tolist()
73
+ self.dataset_meta = self.dataset_meta[self.dataset_meta['msa_id'].isin(subset_ids)]
74
+ else:
75
+ self.dataset = None
76
+
77
+ self.sample = sample
78
+ self.max_msa_len = max_msa_len
79
+ self.reverse = reverse
80
+ self.fim_strategy = fim_strategy
81
+ if fim_strategy in ProteinMemmapDataset._FIM:
82
+ self.fim = ProteinMemmapDataset._FIM[fim_strategy](max_patches=max_patches,
83
+ mask_fraction=mask_fraction,
84
+ always_mask=always_mask,
85
+ add_position_ids=add_position_ids != "none",
86
+ troubleshoot=troubleshoot)
87
+ else:
88
+ raise ValueError(f'Fill-in-the-middle strategy "{fim_strategy}" not recognized.')
89
+
90
+ self.max_position_embeddings = max_position_embeddings
91
+ self.max_seq_position_embeddings = max_seq_position_embeddings
92
+ self.add_position_ids = add_position_ids
93
+
94
+ self.troubleshoot = troubleshoot
95
+
96
+ def __len__(self):
97
+ # meta dataframe has one row for each MSA cluster
98
+ return len(self.dataset_meta)
99
+
100
+ def __getitem__(self, idx):
101
+ # get all the sequences in the cluster
102
+ sequences = self.get_sequences(idx)
103
+ # get total number of sequences in the cluster and choose how many to sample
104
+ orig_num_sequences = len(self.get_index_start_of_sequences(sequences))
105
+ num_sequences = np.random.randint(1, orig_num_sequences + 1) if self.sample else orig_num_sequences
106
+ # sample the sequences
107
+ sequences, position_ids = self.sample_sequences(sequences, num_sequences)
108
+ # with probability 0.5, reverse the sequences and move the last token to the front
109
+ sequences, position_ids = self.reverse_sequences(sequences, position_ids) if (
110
+ self.reverse and np.random.rand() > 0.5) else sequences, position_ids
111
+ # limit the length of the MSA
112
+ sequences = sequences[:self.max_msa_len] if self.max_msa_len > 0 else sequences
113
+ if self.add_position_ids != "none":
114
+ position_ids = position_ids[:self.max_msa_len] if self.max_msa_len > 0 else position_ids
115
+ # convert to tensor
116
+ sequences = torch.asarray(sequences, dtype=torch.int64)
117
+ position_ids = torch.asarray(position_ids, dtype=torch.int64).clamp(0,
118
+ self.max_position_embeddings - 1) if self.add_position_ids!="none" else None
119
+
120
+ if self.troubleshoot:
121
+ print(
122
+ f"Cluster {idx} has {orig_num_sequences} sequences, of which {num_sequences} sampled now. Total MSA length: {len(sequences)}")
123
+ if self.add_position_ids == "1d":
124
+ return dict(input_ids=sequences, position_ids=position_ids, labels=sequences)
125
+ if self.add_position_ids == "2d":
126
+ seq_position_ids = (sequences == AA_TO_ID["<cls>"]).int().cumsum(-1).clamp(0,
127
+ self.max_seq_position_embeddings - 1).contiguous()
128
+ return dict(input_ids=sequences, position_ids=position_ids, seq_position_ids=seq_position_ids,
129
+ labels=sequences)
130
+ return dict(input_ids=sequences, labels=sequences)
131
+
132
+ def get_msa_id(self, idx):
133
+ """Get the MSA ID in the cluster with index `idx`."""
134
+ cluster_meta = self.dataset_meta.iloc[idx]
135
+ return cluster_meta.msa_id
136
+
137
+ def get_idx_from_msa_id(self, msa_id):
138
+ """Get `idx` with the MSA ID"""
139
+ return self.dataset_meta[self.dataset_meta.msa_id == msa_id].index[0]
140
+
141
+ def get_sequences(self, idx):
142
+ """Get the sequences in the cluster with index `idx`."""
143
+ cluster_meta = self.dataset_meta.iloc[idx]
144
+ sequences = self.dataset[cluster_meta.Start : cluster_meta.End]
145
+ return sequences
146
+
147
+ def get_index_start_of_sequences(self, sequences):
148
+ """Get the positions of the start of each sequence in the cluster."""
149
+ return np.where(sequences == 0)[0]
150
+
151
+ def reverse_sequences(self, sequence, position_ids=None):
152
+ """Reverse the sequences and move the last token to the front."""
153
+ sequence = sequence[::-1]
154
+ if position_ids is not None:
155
+ position_ids = position_ids[::-1]
156
+ return np.concatenate([sequence[-1:], sequence[:-1]]), np.concatenate(
157
+ [position_ids[-1:], position_ids[:-1]]) if position_ids is not None else None
158
+
159
+ def sample_sequences(self, sequences, num_sequences, shuffle=True):
160
+ """Sample `num_sequences` from the sequences in the cluster."""
161
+ L = len(sequences)
162
+ # get the indexes of the start of each sequence
163
+ inds = self.get_index_start_of_sequences(sequences)
164
+ # check that there are sequences in the cluster and that there are enough of them
165
+ assert len(inds) > 0, "No sequences found in cluster."
166
+ assert len(inds) >= num_sequences, "Not enough sequences in cluster."
167
+ # sample n_sequences randomly from the sequences
168
+ if shuffle:
169
+ which_seqs = np.random.choice(np.arange(len(inds)), num_sequences, replace=False)
170
+ else:
171
+ which_seqs = np.arange(len(inds))[-num_sequences:]
172
+ # get the tuples of start and end indexes of the sequences
173
+ tuples = [(inds[i], inds[i + 1]) if i < len(inds) - 1 else (inds[i], L) for i in which_seqs]
174
+ if self.troubleshoot:
175
+ print(f"Sampled sequences: {tuples}")
176
+ # concatenate the sequences
177
+ sequences, position_ids = self.fim.apply(sequences, tuples)
178
+ return sequences, position_ids
179
+
180
+
181
+
182
+ def make_dataloader(dataset):
183
+ """Basic function to make a dataloader.
184
+ """
185
+ dataloader = DataLoader(dataset)
186
+ return dataloader
187
+
188
+
189
+ class ProteinDataCollator(object):
190
+ """
191
+ Collate examples into a batch, and pad batch to a specified maximum sequence length,
192
+ or to the longest sequence in the batch if max_sequence_length is None.
193
+ """
194
+ def __init__(self, max_sequence_length: Optional[int] = None):
195
+ """
196
+ Initialize the collator with an optional max_sequence_length.
197
+
198
+ Args:
199
+ max_sequence_length (Optional[int]): The maximum sequence length to pad/truncate to.
200
+ If None, pad to the longest sequence in the batch.
201
+ """
202
+ self.max_sequence_length = max_sequence_length
203
+
204
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
205
+
206
+ input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "input_ids"))
207
+
208
+ longest_seq = max(len(seq) for seq in input_ids)
209
+ if self.max_sequence_length is None:
210
+ max_len = longest_seq
211
+ else:
212
+ max_len = self.max_sequence_length
213
+
214
+ input_ids = self.pad_sequences(input_ids, max_len, padding_value=AA_TO_ID["<pad>"])
215
+
216
+ labels = self.pad_sequences(labels, longest_seq, padding_value=AA_TO_ID["<pad>"])
217
+ labels = self.pad_sequences(labels, max_len, padding_value=-100)
218
+
219
+ return_dict = dict(
220
+ input_ids=input_ids,
221
+ labels=labels,
222
+ attention_mask=input_ids.ne(AA_TO_ID["<pad>"])
223
+ )
224
+
225
+ if "position_ids" in instances[0]:
226
+
227
+ position_ids = [instance["position_ids"] for instance in instances]
228
+ position_ids = self.pad_sequences(position_ids, max_len, padding_value=0)
229
+ return_dict["position_ids"] = position_ids
230
+
231
+ if "seq_position_ids" in instances[0]:
232
+ seq_position_ids = [instance["seq_position_ids"] for instance in instances]
233
+ seq_position_ids = self.pad_sequences(seq_position_ids, max_len, padding_value=0)
234
+ return_dict["seq_position_ids"] = seq_position_ids
235
+
236
+ return return_dict
237
+
238
+ def pad_sequences(self, seqs, max_length, padding_value):
239
+ # truncate long sequences (redundant, already done in __getitem__, maybe safe to remove)
240
+ seqs = [seq[:max_length] for seq in seqs]
241
+
242
+ # pad to same length
243
+ seqs = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=padding_value)
244
+
245
+ # pad to max length
246
+ padding = max_length - seqs.size(1)
247
+ seqs = torch.nn.functional.pad(seqs, (0, padding), value=padding_value)
248
+
249
+ return seqs
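A minimal sketch of wiring the dataset and collator into a padded DataLoader (the paths are assumptions and must point to the memmap files built by protxlstm/data.py):

from torch.utils.data import DataLoader

from protxlstm.dataloaders import ProteinMemmapDataset, ProteinDataCollator

dataset = ProteinMemmapDataset(
    msa_memmap_path="data/open_protein_set_memmap.dat",
    msa_memmap_meta_path="data/open_protein_set_memmap_indices.csv",
    max_msa_len=2048,
    fim_strategy="multiple_span",
    add_position_ids="1d",
)
collator = ProteinDataCollator(max_sequence_length=2048)
loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collator)
batch = next(iter(loader))  # dict with input_ids, labels, attention_mask, position_ids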
protxlstm/fim.py ADDED
@@ -0,0 +1,203 @@
1
+
2
+ # Original code from ProtMamba under Apache License 2.0.
3
+
4
+ from protxlstm.utils import MASK_TO_ID, AA_TO_ID
5
+ import numpy as np
6
+
7
+ class AbstractFIM(object):
8
+ def __init__(self,
9
+ max_patches=5,
10
+ mask_fraction=0.2,
11
+ always_mask=False,
12
+ mask_tokens=MASK_TO_ID,
13
+ eos_token=AA_TO_ID["<eos>"],
14
+ add_position_ids=False,
15
+ troubleshoot=False):
16
+ """
17
+ This class is designed to concatenate sequences based on different scrambling strategies.
18
+ It takes a list of sequences, tuples indicating the start and end indices of each sequence,
19
+ an optional number of patches to sample, and a scrambling strategy as inputs.
20
+ """
21
+ self.troubleshoot = troubleshoot
22
+ self.max_patches = max_patches
23
+ self.mask_fraction = mask_fraction
24
+ self.mask_tokens = mask_tokens
25
+ assert len(
26
+ self.mask_tokens) >= self.max_patches, "Number of mask tokens must be bigger than max number of patches."
27
+ self.eos_token = eos_token
28
+ self.add_position_ids = add_position_ids
29
+ self.always_mask = always_mask
30
+
31
+ def apply(self, sequences, tuples):
32
+ """
33
+ This function concatenates the sequences scrambling each one according to the scrambling strategy.
34
+ """
35
+ input_ids, position_ids = [], []
36
+ for t in tuples:
37
+ seq, pos = self.fim(sequences, t)
38
+ input_ids.extend(seq)
39
+ if self.add_position_ids:
40
+ position_ids.extend(pos)
41
+ if self.add_position_ids:
42
+ return input_ids, position_ids
43
+ return input_ids, None
44
+
45
+ def fim(self, sequences, t):
46
+ """
47
+ This function concatenates the sequence's parts based on the scrambling strategy.
48
+ """
49
+ raise NotImplementedError
50
+
51
+
52
+ class NoFIM(AbstractFIM):
53
+ def __init__(self,
54
+ max_patches=5,
55
+ mask_fraction=0.2,
56
+ always_mask=False,
57
+ mask_tokens=MASK_TO_ID,
58
+ eos_token=AA_TO_ID["<eos>"],
59
+ add_position_ids=False,
60
+ troubleshoot=False):
61
+ super().__init__(max_patches, mask_fraction, always_mask, mask_tokens, eos_token, add_position_ids, troubleshoot)
62
+
63
+ def fim(self, sequences, t):
64
+ """
65
+ This function keeps the sequence identical without any scrambling.
66
+ """
67
+ if self.add_position_ids:
68
+ position_ids = np.arange(t[0], t[1]) - t[0]
69
+ return sequences[t[0]:t[1]], position_ids
70
+ return sequences[t[0]:t[1]], None
71
+
72
+
73
+ class SingleSpanFIM(AbstractFIM):
74
+
75
+ def __init__(self,
76
+ max_patches=5,
77
+ mask_fraction=0.2,
78
+ always_mask=False,
79
+ mask_tokens=MASK_TO_ID,
80
+ eos_token=AA_TO_ID["<eos>"],
81
+ add_position_ids=False,
82
+ troubleshoot=False):
83
+ super().__init__(max_patches, mask_fraction, always_mask, mask_tokens, eos_token, add_position_ids, troubleshoot)
84
+
85
+ def fim(self, sequences, t):
86
+ """
87
+ This function creates and concatenates parts of the sequences based on the OpenAI scrambling strategy.
88
+ It randomly selects two indices within the range of the given tuple,
89
+ splits the sequence into three parts based on these indices, and then concatenates them with the
90
+ masked patch at the end
91
+ """
92
+ new_tuple = tuple(np.sort(np.random.choice(np.arange(t[0] + 1, t[1]), 2, replace=False)))
93
+ part1 = sequences[t[0]:new_tuple[0]]
94
+ part2 = sequences[new_tuple[0]:new_tuple[1]]
95
+ part3 = sequences[new_tuple[1]:t[1]]
96
+ sequence = np.concatenate([part1, [self.mask_tokens["<mask-1>"]], part3, [self.mask_tokens["<mask-1>"]], part2])
97
+ position_ids_sequence = None
98
+ if self.add_position_ids:
99
+ position_ids = np.arange(t[0], t[1]) - t[0]
100
+ position_ids_part1 = position_ids[t[0]:new_tuple[0]]
101
+ position_ids_part2 = position_ids[new_tuple[0]:new_tuple[1]]
102
+ position_ids_part3 = position_ids[new_tuple[1]:t[1]]
103
+ position_ids_sequence = np.concatenate(
104
+ [position_ids_part1, [position_ids_part2[0]], position_ids_part3, [position_ids_part2[0]],
105
+ position_ids_part2])
106
+
107
+ return sequence, position_ids_sequence
108
+
109
+
110
+ class MultipleSpanFIM(AbstractFIM):
111
+ def __init__(self,
112
+ max_patches=5,
113
+ mask_fraction=0.2,
114
+ always_mask=False,
115
+ mask_tokens=MASK_TO_ID,
116
+ eos_token=AA_TO_ID["<eos>"],
117
+ add_position_ids=False,
118
+ troubleshoot=False):
119
+ super().__init__(max_patches, mask_fraction, always_mask, mask_tokens, eos_token, add_position_ids, troubleshoot)
120
+
121
+ def fim(self, sequences, t):
122
+ """
123
+ This function creates and concatenates parts of the sequences based on the inpaint scrambling strategy.
124
+ It randomly selects `2*num_patches` indices within the range of the given tuple,
125
+ splits the sequence into unmasked and masked parts based on these indices, and then concatenates them.
126
+ The number of patches is sampled from a poisson distribution with upper limit `self.max_patches` and average 1.
127
+ The concatenation is done by joining all unmaksed parts (interleaved with mask tokens) and afterwards
128
+ all masked parts (interleaved with mask tokens). At the end of the unmasked parts, a special token is added
129
+ to indicate the end of the unmasked parts, and at the end of the masked parts, a special token is added
130
+ to indicate the end of the masked parts.
131
+ """
132
+ # sample num_patches from a discrete poisson distribution with upper limit L
133
+ def sample_lengths(start, end):
134
+ """
135
+ Sample a length uniformly from 1 to max_L*self.mask_fraction (must be bigger than 1).
136
+ If the length is larger than max_L, return max_L.
137
+ """
138
+ max_L = end - start
139
+ length = np.random.randint(1, max(int(max_L * self.mask_fraction), 2))
140
+ return min(length, max_L)
141
+
142
+ # sample num_patches from a discrete poisson distribution with upper limit max_patches
143
+ num_patches = 1000
144
+ while num_patches > self.max_patches:
145
+ num_patches = np.random.poisson(1)
146
+ if self.always_mask:
147
+ num_patches = max(num_patches, 1)
148
+ # sample num_patches starting points for the masked positions (+ final position)
149
+ start_patches = list(np.sort(np.random.choice(np.arange(t[0] + 1, t[1]),
150
+ num_patches,
151
+ replace=False))) + [t[1]]
152
+ # sample num_patches lengths of the patches
153
+ len_patches = [sample_lengths(start_patches[i], start_patches[i + 1])
154
+ for i in range(len(start_patches) - 1)]
155
+ # create masked tuples with start and end indices of the patches
156
+ masked_tuples = [(start_patches[i], start_patches[i] + len_patches[i]) for i in range(len(start_patches) - 1)]
157
+ # split the sequences into unmasked and masked parts
158
+ unmasked_sequence, masked_sequence, unmasked_position_ids, masked_position_ids = self.split_sequences(sequences,
159
+ t,
160
+ masked_tuples)
161
+
162
+ if self.troubleshoot:
163
+ print(f"For sequence in {t}: sampled {num_patches=}, {start_patches=}, {len_patches=}, {masked_tuples=}")
164
+ # concatenate the unmasked and masked parts
165
+ return unmasked_sequence + masked_sequence, unmasked_position_ids + masked_position_ids if self.add_position_ids else None
166
+
167
+ def split_sequences(self, sequences, t, masked_tuples):
168
+ """
169
+ This function splits the sequences into unmasked and masked parts based on the given tuples.
170
+ Args:
171
+ t (tuple): The start and end index of each sequence.
172
+ masked_tuples (list): A list of tuples specifying the indices for masked regions.
173
+ Returns:
174
+ unmasked_parts (list): The unmasked parts of the sequences interleaved with mask_tokens.
175
+ masked_parts (list): The masked parts of the sequences interleaved with mask_tokens.
176
+ """
177
+ unmasked_parts, masked_parts = [], []
178
+ unmasked_positions, masked_positions = [], []
179
+ position_ids = None
180
+ start, end = t
181
+ if self.add_position_ids:
182
+ position_ids = np.arange(start, end) - start
183
+ for i, region in enumerate(masked_tuples):
184
+ mask_token = self.mask_tokens[f"<mask-{i + 1}>"]
185
+ unmasked_parts.extend(sequences[start:region[0]])
186
+ unmasked_parts.append(mask_token)
187
+ masked_parts.append(mask_token)
188
+ masked_parts.extend(sequences[region[0]:region[1]])
189
+ if self.add_position_ids:
190
+ unmasked_positions.extend(position_ids[start-t[0]:region[0]-t[0]])
191
+ unmasked_positions.append(position_ids[region[0]-t[0]])
192
+ masked_positions.append(position_ids[region[0]-t[0]])
193
+ masked_positions.extend(position_ids[region[0]-t[0]:region[1]-t[0]])
194
+
195
+ start = region[1]
196
+ unmasked_parts.extend(sequences[start:end])
197
+ if self.add_position_ids:
198
+ unmasked_positions.extend(position_ids[start-t[0]:end-t[0]])
199
+ if len(masked_tuples) > 0:
200
+ unmasked_parts.append(self.eos_token)
201
+ if self.add_position_ids:
202
+ unmasked_positions.append(0)
203
+ return unmasked_parts, masked_parts, unmasked_positions, masked_positions
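A toy sketch of applying the multiple-span FIM transform to a single dummy token array (token values are placeholders, not a real MSA):

import numpy as np

from protxlstm.fim import MultipleSpanFIM

fim = MultipleSpanFIM(max_patches=3, mask_fraction=0.2, add_position_ids=True)
tokens = np.array([0, 5, 6, 7, 8, 9, 10, 11, 12])  # one sequence starting with <cls> (token 0)
input_ids, position_ids = fim.apply(tokens, [(0, len(tokens))])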
protxlstm/generation.py ADDED
@@ -0,0 +1,384 @@
1
+ # Original code from ProtMamba under Apache License 2.0.
2
+ #
3
+ # Modifications made by Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
4
+ # - Add option to pass input state for generation
5
+ # - Add functions to generate sequences with xlstm
6
+
7
+ import numpy as np
8
+ import torch
9
+ from protxlstm.mamba_utils_generation import (
10
+ InferenceParams,
11
+ GenerationMixin,
12
+ GreedySearchDecoderOnlyOutput,
13
+ modify_logits_for_top_p_filtering,
14
+ modify_logits_for_min_p_filtering,
15
+ modify_logit_for_repetition_penalty,
16
+ SampleDecoderOnlyOutput,
17
+ update_graph_cache
18
+ )
19
+
20
+ from protxlstm.utils import AA_TO_ID, decode_sequence
21
+
22
+ def sample_safe(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
23
+ """Sample from top-k logits.
24
+ Arguments:
25
+ logits: Tensor of shape (batch_size, vocab_size)
26
+ """
27
+ if top_k == 1: # Short-circuit for greedy decoding
28
+ return logits.argmax(dim=-1)
29
+ else:
30
+ if top_p > 0.0:
31
+ assert top_p <= 1.0, "top-p should be in (0, 1]."
32
+ if top_k > 0:
33
+ top_k = min(top_k, logits.size(-1)) # Safety check
34
+ logits_top, indices = torch.topk(logits, top_k, dim=-1)
35
+ if temperature != 1.0:
36
+ logits_top /= temperature
37
+ modify_logits_for_top_p_filtering(logits_top, top_p)
38
+
39
+ return indices[
40
+ torch.arange(indices.shape[0], device=indices.device),
41
+ torch.multinomial(
42
+ torch.softmax(logits_top, dim=-1), num_samples=1
43
+ ).squeeze(dim=-1),
44
+ ]
45
+ else:
46
+ if min_p > 0.0:
47
+ logits_top = logits.clone()
48
+ max_prob = logits_top[..., 0].item()
49
+ min_prob = max_prob * min_p
50
+ modify_logits_for_min_p_filtering(logits_top, min_prob)
51
+ if temperature != 1.0:
52
+ logits_top /= temperature
53
+ return torch.multinomial(
54
+ torch.softmax(logits_top, dim=-1), num_samples=1
55
+ ).squeeze(dim=-1)
56
+ # Clone so that when we modify for top_p we don't change the original logits
57
+ logits_top = logits / temperature if temperature != 1.0 else logits.clone()
58
+ modify_logits_for_top_p_filtering(logits_top, top_p)
59
+ return torch.multinomial(
60
+ torch.softmax(logits_top, dim=-1), num_samples=1
61
+ ).squeeze(dim=-1)
62
+
63
+
64
+ @torch.inference_mode()
65
+ def decode_safe(
66
+ input_ids,
67
+ position_ids,
68
+ seq_position_ids,
69
+ is_fim,
70
+ model,
71
+ max_length,
72
+ state=None,
73
+ top_k=1,
74
+ top_p=0.0,
75
+ min_p=0.0,
76
+ temperature=1.0,
77
+ repetition_penalty=1.0,
78
+ eos_token_id=None,
79
+ teacher_outputs=None,
80
+ vocab_size=None,
81
+ cg=False,
82
+ enable_timing=False,
83
+ streamer = None,
84
+ chunk_chunk_size = 2**15,
85
+ ):
86
+ """Decoding, either greedy or with top-k or top-p sampling.
87
+ If top-k = 0, don't limit the number of candidates (pure sampling).
88
+ Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
89
+ then top-p.
90
+ We assume that all sequences in the same batch have the same length.
91
+
92
+ Arguments:
93
+ input_ids: (batch, seq_len)
94
+ max_length: int
95
+ is_fim: dictionary with mask indices and associated position indices
96
+ teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
97
+ logits, the next token is taken from the teacher_outputs. Useful for testing.
98
+ Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
99
+ sequences: (batch, max_length)
100
+ scores: tuples of (batch, vocab_size)
101
+ """
102
+ if streamer is not None:
103
+ streamer.put(input_ids.cpu())
104
+
105
+ batch_size, seqlen_og = input_ids.shape
106
+ teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
107
+ if cg:
108
+ if not hasattr(model, "_decoding_cache"):
109
+ model._decoding_cache = None
110
+ model._decoding_cache = update_graph_cache(
111
+ model,
112
+ model._decoding_cache,
113
+ batch_size,
114
+ seqlen_og,
115
+ max_length,
116
+ )
117
+ inference_params = model._decoding_cache.inference_params
118
+ inference_params.reset(max_length, batch_size)
119
+ else:
120
+ inference_params = InferenceParams(
121
+ max_seqlen=max_length, max_batch_size=batch_size
122
+ )
123
+
124
+ def get_logits(input_ids, position_ids, seq_position_ids, inference_params):
125
+ decoding = inference_params.seqlen_offset > 0
126
+ if not cg or not decoding:
127
+ logits = model(
128
+ input_ids,
129
+ position_ids=position_ids,
130
+ seq_position_ids=seq_position_ids,
131
+ inference_params=inference_params,
132
+ num_last_tokens=1,
133
+ ).logits.squeeze(dim=1)
134
+ else:
135
+ logits = model._decoding_cache.run(
136
+ input_ids,
137
+ position_ids,
138
+ inference_params.seqlen_offset,
139
+ seq_position_ids=seq_position_ids,
140
+ ).squeeze(dim=1)
141
+ return logits[..., :vocab_size] if vocab_size is not None else logits
142
+
143
+ def get_xlstm_logits_step(input_ids, position_ids, seq_position_ids, state):
144
+
145
+ if not input_ids.shape[1] == 1:
146
+
147
+ for i in range(input_ids.shape[1]):
148
+ if position_ids != None:
149
+ token_position_ids = position_ids[:,i:(i+1)]
150
+ else:
151
+ token_position_ids = None
152
+ if seq_position_ids != None:
153
+ token_seq_position_ids = seq_position_ids[:,i:(i+1)]
154
+ else:
155
+ token_seq_position_ids = None
156
+ logits, state = model.step(input_ids[:,i:(i+1)], state, position_ids=token_position_ids, seq_position_ids=token_seq_position_ids)
157
+
158
+ else:
159
+ logits, state = model.step(input_ids, state, position_ids=position_ids, seq_position_ids=seq_position_ids)
160
+
161
+ logits = logits.squeeze(dim=1)
162
+ if vocab_size is not None:
163
+ logits = logits[..., :vocab_size]
164
+
165
+ return logits, state
166
+
167
+ def get_xlstm_logits_chunkwise(input_ids, position_ids, seq_position_ids, chunk_chunk_size=2**15, state=None):
168
+
169
+ assert model.config.config_dataclass.mlstm_block.mlstm.backend == "chunkwise_variable"
170
+
171
+ for chunk in range(input_ids.shape[1]//chunk_chunk_size+1):
172
+
173
+ start_idx = chunk*chunk_chunk_size
174
+ end_idx = min((chunk+1)*chunk_chunk_size, input_ids.shape[1])
175
+
176
+ if start_idx == end_idx:
177
+ pass
178
+
179
+ else:
180
+ input_ids_chunk = input_ids[:, start_idx:end_idx]
181
+
182
+ if not position_ids == None:
183
+ position_ids_chunk = position_ids[:, start_idx:end_idx]
184
+ else:
185
+ position_ids_chunk = None
186
+
187
+ if not seq_position_ids == None:
188
+ seq_position_ids_chunk = seq_position_ids[:, start_idx:end_idx]
189
+ else:
190
+ seq_position_ids_chunk = None
191
+
192
+ outputs = model(input_ids_chunk, position_ids=position_ids_chunk, seq_position_ids=seq_position_ids_chunk, state=state)
193
+ logits, state = outputs.logits, outputs.state
194
+
195
+ logits = logits[:,-1,:]
196
+ logits = logits.squeeze(dim=1)
197
+ if vocab_size is not None:
198
+ logits = logits[..., :vocab_size]
199
+
200
+ return logits, state
201
+
202
+ def sample_tokens(logits, inference_params):
203
+ if (
204
+ teacher_outputs is None
205
+ or teacher_output_len <= inference_params.seqlen_offset
206
+ ):
207
+ token = sample_safe(
208
+ logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature
209
+ )
210
+ else:
211
+ token = teacher_outputs[:, inference_params.seqlen_offset]
212
+ # return rearrange(token, "b -> b 1")
213
+ return token.unsqueeze(1)
214
+
215
+ def get_fim_position_id(
216
+ last_position_ids, sampled_tokens, is_fim, repeat_next=False
217
+ ):
218
+ if type(is_fim) is dict:
219
+ val = int(last_position_ids) + 1
220
+ should_repeat_next = False
221
+ if is_fim and int(sampled_tokens) in is_fim:
222
+ val = is_fim[int(sampled_tokens)]
223
+ should_repeat_next = True
224
+ elif repeat_next:
225
+ val = int(last_position_ids)
226
+ return torch.full_like(last_position_ids, fill_value=val), should_repeat_next
227
+ else:
228
+ t = [get_fim_position_id(last_position_ids_, sampled_tokens_, is_fim_dict, repeat_next) for
229
+ (last_position_ids_, sampled_tokens_, is_fim_dict) in
230
+ zip(last_position_ids, sampled_tokens, is_fim)]
231
+ return torch.stack([t_[0] for t_ in t], dim=0), t[0][1]
232
+
233
+ def should_stop(current_token, inference_params):
234
+ if inference_params.seqlen_offset == 0:
235
+ return False
236
+ if eos_token_id is not None and (current_token == eos_token_id).any():
237
+ if current_token.shape[1] > 1:
238
+ raise NotImplementedError("Batched eos_token_id not supported")
239
+ return True
240
+ if inference_params.seqlen_offset >= max_length - 1:
241
+ return True
242
+ return False
243
+
244
+ start = torch.cuda.Event(enable_timing=enable_timing)
245
+ end = torch.cuda.Event(enable_timing=enable_timing)
246
+
247
+ if enable_timing:
248
+ start.record()
249
+ scores, sequences = [], [input_ids]
250
+ new_position_ids, new_seq_position_ids = [position_ids], [seq_position_ids]
251
+ sequences_cat = input_ids
252
+ repeat_next = False
253
+ if position_ids.shape[0] > 1:
254
+ raise NotImplementedError("Batched generation with position_ids not supported")
255
+
256
+ encode_context=True
257
+ while not should_stop(sequences[-1], inference_params):
258
+
259
+ from protxlstm.models.xlstm import xLSTMLMHeadModel
260
+ if isinstance(model, xLSTMLMHeadModel):
261
+ if encode_context:
262
+ with torch.no_grad():
263
+ logits, state = get_xlstm_logits_chunkwise(sequences[-1], new_position_ids[-1], new_seq_position_ids[-1], state=state, chunk_chunk_size=chunk_chunk_size)
264
+ encode_context = False
265
+ else:
266
+ logits, state = get_xlstm_logits_step(sequences[-1], new_position_ids[-1], new_seq_position_ids[-1], state=state)
267
+ else:
268
+ logits = get_logits(sequences[-1], new_position_ids[-1], new_seq_position_ids[-1], inference_params)
269
+
270
+ scores.append(logits)
271
+
272
+ inference_params.seqlen_offset += sequences[-1].shape[1]
273
+ if repetition_penalty == 1.0:
274
+ sampled_tokens = sample_tokens(scores[-1], inference_params)
275
+ else:
276
+ logits = modify_logit_for_repetition_penalty(
277
+ scores[-1].clone(), sequences_cat, repetition_penalty
278
+ )
279
+ sampled_tokens = sample_tokens(logits, inference_params)
280
+ sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
281
+ sequences.append(sampled_tokens)
282
+ # Update position_ids
283
+ if position_ids is not None:
284
+ last_position_ids, repeat_next = get_fim_position_id(
285
+ new_position_ids[-1][:, -1:], sampled_tokens, is_fim, repeat_next
286
+ )
287
+ new_position_ids.append(last_position_ids)
288
+ # Update seq_position_ids
289
+ if seq_position_ids is not None:
290
+ new_seq_position_ids.append(new_seq_position_ids[-1][:, -1:])
291
+
292
+ if streamer is not None:
293
+ streamer.put(sampled_tokens.cpu())
294
+ if streamer is not None:
295
+ streamer.end()
296
+ if enable_timing:
297
+ end.record()
298
+ torch.cuda.synchronize()
299
+ print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
300
+ output_cls = (
301
+ GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
302
+ )
303
+ return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
304
+
305
+
306
+ class GenerationMixinSafe(GenerationMixin):
307
+
308
+ def generate(
309
+ self,
310
+ input_ids,
311
+ position_ids,
312
+ seq_position_ids,
313
+ is_fim=None,
314
+ state=None,
315
+ max_length=1,
316
+ top_k=1,
317
+ top_p=0.0,
318
+ min_p=0.0,
319
+ temperature=1.0,
320
+ return_dict_in_generate=False,
321
+ output_scores=False,
322
+ chunk_chunk_size=2**15,
323
+ **kwargs,
324
+ ):
325
+
326
+ output = decode_safe(
327
+ input_ids,
328
+ position_ids,
329
+ seq_position_ids,
330
+ is_fim,
331
+ self,
332
+ max_length,
333
+ state=state,
334
+ top_k=top_k,
335
+ top_p=top_p,
336
+ min_p=min_p,
337
+ temperature=temperature,
338
+ chunk_chunk_size=chunk_chunk_size,
339
+ **kwargs,
340
+ )
341
+ if not output_scores:
342
+ output.scores = None
343
+ return output if return_dict_in_generate else output.sequences
344
+
345
+
346
+ def generate_sequence(model, tokens, position_ids=None, seq_position_ids=None, state=None, is_fim=False, max_length=2000, temperature=1., top_p=0.0, top_k=1,
347
+ return_dict_in_generate=False, output_scores=False, eos_token_id=AA_TO_ID["<cls>"], device="cuda", chunk_chunk_size=2**15):
348
+ """Generating, either greedy or with top-k or top-p sampling.
349
+ If top-k = 0, don't limit the number of candidates (pure sampling).
350
+ Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
351
+ then top-p. We assume that all sequences in the same batch have the same length.
352
+ """
353
+ input_ids = tokens.to(device)
354
+ position_ids = position_ids.to(device) if position_ids is not None else None
355
+ seq_position_ids = seq_position_ids.to(device) if seq_position_ids is not None else None
356
+ # generate sequence
357
+ out = model.generate(input_ids=input_ids,
358
+ position_ids=position_ids,
359
+ seq_position_ids=seq_position_ids,
360
+ is_fim=is_fim,
361
+ state=state,
362
+ max_length=max_length,
363
+ temperature=temperature,
364
+ top_p=top_p,
365
+ top_k=top_k,
366
+ return_dict_in_generate=return_dict_in_generate,
367
+ output_scores=output_scores,
368
+ eos_token_id=eos_token_id,
369
+ chunk_chunk_size=chunk_chunk_size,
370
+ )
371
+ sequences = out.sequences
372
+ dic = {"input": [decode_sequence(seq) for seq in sequences[:, :input_ids.shape[-1]].cpu().numpy()],
373
+ "generated": [decode_sequence(seq) for seq in sequences[:, input_ids.shape[-1]:].cpu().numpy()],
374
+ "input_tokens": [seq for seq in sequences[:, :input_ids.shape[-1]].cpu().numpy()],
375
+ "generated_tokens": [seq for seq in sequences[:, input_ids.shape[-1]:].cpu().numpy()]}
376
+ if output_scores:
377
+ dic["scores"] = np.array([el.to(torch.float32).cpu().numpy() for el in out.scores]).transpose(1, 0, 2)
378
+ return dic
379
+
380
+
381
+
382
+
383
+
384
+
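A minimal sketch of calling generate_sequence with a short tokenized context. The model is assumed to be an already loaded xLSTM/Mamba checkpoint on the target device, the tokenizer call follows its use in protxlstm/data.py, and the per-sequence position-id handling is simplified relative to the dataset code.

import torch

from protxlstm.generation import generate_sequence
from protxlstm.utils import tokenizer

context_tokens = tokenizer(["MKVLLT", "MKILLT"], concatenate=True)  # shape (1, L)
position_ids = torch.arange(context_tokens.shape[1])[None, :]
out = generate_sequence(model, context_tokens, position_ids=position_ids, is_fim={},
                        max_length=context_tokens.shape[1] + 200, temperature=1.0,
                        top_k=10, top_p=0.95, return_dict_in_generate=True,
                        output_scores=True, device="cuda:0")
print(out["generated"][0])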
protxlstm/index.html ADDED
@@ -0,0 +1,16 @@
1
+ <!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 3.2 Final//EN">
2
+ <html>
3
+ <head>
4
+ <title>Index of /research/Bio-xLSTM/downloads/Prot-xLSTM/checkpoints/protxlstm_26M_30B</title>
5
+ </head>
6
+ <body>
7
+ <h1>Index of /research/Bio-xLSTM/downloads/Prot-xLSTM/checkpoints/protxlstm_26M_30B</h1>
8
+ <pre><img src="/icons/blank.gif" alt="Icon "> <a href="?C=N;O=D">Name</a> <a href="?C=M;O=A">Last modified</a> <a href="?C=S;O=A">Size</a> <a href="?C=D;O=A">Description</a><hr><img src="/icons/back.gif" alt="[PARENTDIR]"> <a href="/research/Bio-xLSTM/downloads/Prot-xLSTM/checkpoints/">Parent Directory</a> -
9
+ <img src="/icons/unknown.gif" alt="[ ]"> <a href="config.json">config.json</a> 2024-11-04 14:36 1.8K
10
+ <img src="/icons/unknown.gif" alt="[ ]"> <a href="optimizer.pt">optimizer.pt</a> 2024-11-04 14:36 198M
11
+ <img src="/icons/binary.gif" alt="[ ]"> <a href="pytorch_model.bin">pytorch_model.bin</a> 2024-11-04 14:36 99M
12
+ <img src="/icons/unknown.gif" alt="[ ]"> <a href="rng_state.pth">rng_state.pth</a> 2024-11-04 14:36 14K
13
+ <img src="/icons/unknown.gif" alt="[ ]"> <a href="scheduler.pt">scheduler.pt</a> 2024-11-04 14:36 1.0K
14
+ <img src="/icons/unknown.gif" alt="[ ]"> <a href="trainer_state.json">trainer_state.json</a> 2024-11-04 14:36 2.4M
15
+ <hr></pre>
16
+ </body></html>
protxlstm/mamba_utils_generation.py ADDED
@@ -0,0 +1,382 @@
1
+ # From: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/utils/generation.py
2
+ # Copyright (c) 2023, Albert Gu, Tri Dao.
3
+ import gc
4
+ from dataclasses import dataclass, field
5
+ from typing import Callable, Optional
6
+
7
+ import torch
8
+ from torch import Tensor
9
+ from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer
10
+
11
+
12
+ @dataclass
13
+ class InferenceParams:
14
+ """Inference parameters that are passed to the main model in order
15
+ to efficiently calculate and store the context during inference."""
16
+
17
+ max_seqlen: int
18
+ max_batch_size: int
19
+ seqlen_offset: int = 0
20
+ batch_size_offset: int = 0
21
+ key_value_memory_dict: dict = field(default_factory=dict)
22
+ lengths_per_sample: Optional[Tensor] = None
23
+
24
+ def reset(self, max_seqlen, max_batch_size):
25
+ self.max_seqlen = max_seqlen
26
+ self.max_batch_size = max_batch_size
27
+ self.seqlen_offset = 0
28
+ if self.lengths_per_sample is not None:
29
+ self.lengths_per_sample.zero_()
30
+
31
+
32
+ def modify_logits_for_min_p_filtering(logits, min_p):
33
+ """Set the logits for none min_p values to -inf. Done in-place."""
34
+ if min_p <= 0.0 or min_p >= 1.0:
35
+ return
36
+ indices_to_remove = logits < min_p
37
+ logits.masked_fill_(indices_to_remove, float("-Inf"))
38
+ # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
39
+ # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
40
+ def modify_logits_for_top_k_filtering(logits, top_k):
41
+ """Set the logits for none top-k values to -inf. Done in-place."""
42
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
43
+ logits.masked_fill_(indices_to_remove, float("-Inf"))
44
+
45
+
46
+ # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
47
+ # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
48
+ def modify_logits_for_top_p_filtering(logits, top_p):
49
+ """Set the logits for none top-p values to -inf. Done in-place."""
50
+ if top_p <= 0.0 or top_p >= 1.0:
51
+ return
52
+ # First sort and calculate cumulative sum of probabilities.
53
+ sorted_logits, sorted_indices = torch.sort(logits, descending=False)
54
+ cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
55
+ # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
56
+ sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
57
+ # scatter sorted tensors to original indexing
58
+ indices_to_remove = sorted_indices_to_remove.scatter(
59
+ 1, sorted_indices, sorted_indices_to_remove
60
+ )
61
+ logits.masked_fill_(indices_to_remove, float("-inf"))
62
+
63
+
64
+ def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_penalty=1.0):
65
+ """Apply repetition penalty. See https://arxiv.org/abs/1909.05858
66
+ logits: (batch_size, vocab_size)
67
+ prev_output_tokens: (batch_size, seq_len)
68
+ """
69
+ if repetition_penalty == 1.0:
70
+ return logits
71
+ score = torch.gather(logits, 1, prev_output_tokens)
72
+ # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
73
+ score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
74
+ logits.scatter_(1, prev_output_tokens, score)
75
+ return logits
76
+
77
+
78
+ def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
79
+ """Sample from top-k logits.
80
+ Arguments:
81
+ logits: Tensor of shape (batch_size, vocab_size)
82
+ """
83
+ if top_k == 1: # Short-circuit for greedy decoding
84
+ return logits.argmax(dim=-1)
85
+ else:
86
+ if top_p > 0.0:
87
+ assert top_p <= 1.0, "top-p should be in (0, 1]."
88
+ if top_k > 0:
89
+ top_k = min(top_k, logits.size(-1)) # Safety check
90
+ logits_top, indices = torch.topk(logits, top_k, dim=-1)
91
+ if temperature != 1.0:
92
+ logits_top /= temperature
93
+ modify_logits_for_top_p_filtering(logits_top, top_p)
94
+ return indices[
95
+ torch.arange(indices.shape[0], device=indices.device),
96
+ torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
97
+ ]
98
+ else:
99
+ if min_p > 0.0:
100
+ logits_top = logits.clone()
101
+ max_prob = logits_top[..., 0].item()
102
+ min_prob = max_prob * min_p
103
+ modify_logits_for_min_p_filtering(logits_top, min_prob)
104
+ if temperature != 1.0:
105
+ logits_top /= temperature
106
+ return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
107
+ # Clone so that when we modify for top_p we don't change the original logits
108
+ logits_top = logits / temperature if temperature != 1.0 else logits.clone()
109
+ modify_logits_for_top_p_filtering(logits_top, top_p)
110
+ return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
111
+ dim=-1
112
+ )
113
+
114
+
115
+ @torch.inference_mode()
116
+ def decode(
117
+ input_ids,
118
+ model,
119
+ max_length,
120
+ top_k=1,
121
+ top_p=0.0,
122
+ min_p=0.0,
123
+ temperature=1.0,
124
+ repetition_penalty=1.0,
125
+ eos_token_id=None,
126
+ teacher_outputs=None,
127
+ vocab_size=None,
128
+ cg=False,
129
+ enable_timing=False,
130
+ streamer: Optional[TextStreamer] = None
131
+ ):
132
+ """Decoding, either greedy or with top-k or top-p sampling.
133
+ If top-k = 0, don't limit the number of candidates (pure sampling).
134
+ Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
135
+ then top-p.
136
+ We assume that all sequences in the same batch have the same length.
137
+
138
+ Arguments:
139
+ input_ids: (batch, seq_len)
140
+ max_length: int
141
+ teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
142
+ logits, the next token is taken from the teacher_outputs. Useful for testing.
143
+ Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
144
+ sequences: (batch, max_length)
145
+ scores: tuples of (batch, vocab_size)
146
+ """
147
+ if streamer is not None:
148
+ streamer.put(input_ids.cpu())
149
+
150
+ batch_size, seqlen_og = input_ids.shape
151
+ teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
152
+ if cg:
153
+ if not hasattr(model, "_decoding_cache"):
154
+ model._decoding_cache = None
155
+ model._decoding_cache = update_graph_cache(
156
+ model,
157
+ model._decoding_cache,
158
+ batch_size,
159
+ seqlen_og,
160
+ max_length,
161
+ )
162
+ inference_params = model._decoding_cache.inference_params
163
+ inference_params.reset(max_length, batch_size)
164
+ else:
165
+ inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
166
+
167
+ def get_logits(input_ids, inference_params):
168
+ decoding = inference_params.seqlen_offset > 0
169
+ if decoding:
170
+ position_ids = torch.full(
171
+ (batch_size, 1),
172
+ inference_params.seqlen_offset,
173
+ dtype=torch.long,
174
+ device=input_ids.device,
175
+ )
176
+ else:
177
+ position_ids = None
178
+ if not cg or not decoding:
179
+ logits = model(
180
+ input_ids,
181
+ position_ids=position_ids,
182
+ inference_params=inference_params,
183
+ num_last_tokens=1,
184
+ ).logits.squeeze(dim=1)
185
+ else:
186
+ logits = model._decoding_cache.run(
187
+ input_ids, position_ids, inference_params.seqlen_offset
188
+ ).squeeze(dim=1)
189
+ return logits[..., :vocab_size] if vocab_size is not None else logits
190
+
191
+ def sample_tokens(logits, inference_params):
192
+ if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
193
+ token = sample(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature)
194
+ else:
195
+ token = teacher_outputs[:, inference_params.seqlen_offset]
196
+ # return rearrange(token, "b -> b 1")
197
+ return token.unsqueeze(1)
198
+
199
+ def should_stop(current_token, inference_params):
200
+ if inference_params.seqlen_offset == 0:
201
+ return False
202
+ if eos_token_id is not None and (current_token == eos_token_id).all():
203
+ return True
204
+ if inference_params.seqlen_offset >= max_length - 1:
205
+ return True
206
+ return False
207
+
208
+ start = torch.cuda.Event(enable_timing=enable_timing)
209
+ end = torch.cuda.Event(enable_timing=enable_timing)
210
+
211
+ if enable_timing:
212
+ start.record()
213
+ scores, sequences = [], [input_ids]
214
+ sequences_cat = input_ids
215
+ while not should_stop(sequences[-1], inference_params):
216
+ scores.append(get_logits(sequences[-1], inference_params))
217
+ inference_params.seqlen_offset += sequences[-1].shape[1]
218
+ if repetition_penalty == 1.0:
219
+ sampled_tokens = sample_tokens(scores[-1], inference_params)
220
+ else:
221
+ logits = modify_logit_for_repetition_penalty(
222
+ scores[-1].clone(), sequences_cat, repetition_penalty
223
+ )
224
+ sampled_tokens = sample_tokens(logits, inference_params)
225
+ sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
226
+ sequences.append(sampled_tokens)
227
+ if streamer is not None:
228
+ streamer.put(sampled_tokens.cpu())
229
+ if streamer is not None:
230
+ streamer.end()
231
+ if enable_timing:
232
+ end.record()
233
+ torch.cuda.synchronize()
234
+ print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
235
+ output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
236
+ return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
237
+
238
+
239
+ class GenerationMixin:
240
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
241
+ raise NotImplementedError
242
+
243
+ def generate(
244
+ self,
245
+ input_ids,
246
+ max_length,
247
+ top_k=1,
248
+ top_p=0.0,
249
+ min_p=0.0,
250
+ temperature=1.0,
251
+ return_dict_in_generate=False,
252
+ output_scores=False,
253
+ **kwargs,
254
+ ):
255
+ output = decode(
256
+ input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p = min_p, temperature=temperature, **kwargs
257
+ )
258
+ if not output_scores:
259
+ output.scores = None
260
+ return output if return_dict_in_generate else output.sequences
261
+
262
+
263
+ @dataclass
264
+ class DecodingCGCache:
265
+ max_batch_size: int = 0
266
+ max_seqlen: int = 0
267
+ device = None
268
+ dtype = None
269
+ callables: dict = field(default_factory=dict)
270
+ mempool = None
271
+ inference_params: Optional[InferenceParams] = None
272
+ run: Optional[Callable] = None
273
+
274
+
275
+ @torch.inference_mode()
276
+ def update_graph_cache(
277
+ model,
278
+ cache,
279
+ batch_size,
280
+ seqlen_og,
281
+ max_seqlen,
282
+ decoding_seqlens=(1,),
283
+ dtype=None,
284
+ n_warmups=2,
285
+ ):
286
+ if cache is None:
287
+ cache = DecodingCGCache()
288
+ param_example = next(iter(model.parameters()))
289
+ device = param_example.device
290
+ if dtype is None:
291
+ dtype = param_example.dtype
292
+ if (
293
+ (device, dtype) != (cache.device, cache.dtype)
294
+ or batch_size > cache.max_batch_size
295
+ or max_seqlen > cache.max_seqlen
296
+ ): # Invalidate the cache
297
+ cache.callables = {}
298
+ cache.mempool = None
299
+ cache.inference_params = None
300
+ gc.collect()
301
+ cache.device, cache.dtype = device, dtype
302
+ cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
303
+ assert hasattr(model, "allocate_inference_cache"), "CUDA graph decoding requires that the model has a method allocate_inference_cache"
304
+ inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
305
+ lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
306
+ cache.inference_params = InferenceParams(
307
+ max_seqlen=max_seqlen,
308
+ max_batch_size=batch_size,
309
+ seqlen_offset=seqlen_og,
310
+ key_value_memory_dict=inf_cache,
311
+ lengths_per_sample=lengths_per_sample,
312
+ )
313
+ cache.mempool = torch.cuda.graphs.graph_pool_handle()
314
+ for decoding_seqlen in decoding_seqlens:
315
+ if (batch_size, decoding_seqlen) not in cache.callables:
316
+ cache.callables[batch_size, decoding_seqlen] = capture_graph(
317
+ model,
318
+ cache.inference_params,
319
+ batch_size,
320
+ max_seqlen,
321
+ decoding_seqlen=decoding_seqlen,
322
+ mempool=cache.mempool,
323
+ n_warmups=n_warmups,
324
+ )
325
+
326
+ def dispatch(input_ids, position_ids, seqlen):
327
+ batch_size, decoding_seqlen = input_ids.shape[:2]
328
+ return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)
329
+
330
+ cache.run = dispatch
331
+ cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing
332
+ return cache
333
+
334
+
335
+ def capture_graph(
336
+ model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
337
+ ):
338
+ device = next(iter(model.parameters())).device
339
+ input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
340
+ position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
341
+ seqlen_offset_og = inference_params.seqlen_offset
342
+ inference_params.seqlen_offset = max_seqlen - decoding_seqlen
343
+ inference_params.lengths_per_sample[:] = inference_params.seqlen_offset
344
+
345
+ # Warmup before capture
346
+ s = torch.cuda.Stream()
347
+ s.wait_stream(torch.cuda.current_stream())
348
+ with torch.cuda.stream(s):
349
+ for _ in range(n_warmups):
350
+ logits = model(
351
+ input_ids,
352
+ position_ids=position_ids,
353
+ inference_params=inference_params,
354
+ num_last_tokens=decoding_seqlen,
355
+ ).logits
356
+ s.synchronize()
357
+ # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
358
+ # which requires that graph launch and non-captured launch to not overlap (I think,
359
+ # that's how I interpret the documentation). I'm not sure if this is required.
360
+ if torch.distributed.is_initialized():
361
+ torch.distributed.barrier()
362
+ torch.cuda.current_stream().wait_stream(s)
363
+ # Captures the graph
364
+ # To allow capture, automatically sets a side stream as the current stream in the context
365
+ graph = torch.cuda.CUDAGraph()
366
+ with torch.cuda.graph(graph, pool=mempool):
367
+ logits = model(
368
+ input_ids,
369
+ position_ids=position_ids,
370
+ inference_params=inference_params,
371
+ num_last_tokens=decoding_seqlen,
372
+ ).logits
373
+
374
+ def run(new_input_ids, new_position_ids, seqlen):
375
+ inference_params.lengths_per_sample[:] = seqlen
376
+ input_ids.copy_(new_input_ids)
377
+ position_ids.copy_(new_position_ids)
378
+ graph.replay()
379
+ return logits.clone()
380
+
381
+ inference_params.seqlen_offset = seqlen_offset_og
382
+ return run
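The decoding loop above delegates token selection to the imported `sample` helper; the snippet below is a minimal, standalone sketch (not part of the repo) of the top-k-then-top-p order described in the `decode` docstring.

```python
# Standalone sketch (hypothetical, not repo code) of the sampling order used by decode():
# apply top-k first, then top-p (nucleus) filtering, then draw one token per batch row.
import torch

def sample_top_k_top_p(logits, top_k=0, top_p=0.0, temperature=1.0):
    logits = logits / max(temperature, 1e-8)
    if top_k > 0:
        # keep only the k largest logits
        kth = torch.topk(logits, min(top_k, logits.size(-1)), dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    if top_p > 0.0:
        # nucleus filtering: keep the smallest prefix of sorted tokens covering top_p mass
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        probs = torch.softmax(sorted_logits, dim=-1)
        mask = probs.cumsum(dim=-1) - probs > top_p
        sorted_logits = sorted_logits.masked_fill(mask, float("-inf"))
        logits = torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)
    return torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)

tokens = sample_top_k_top_p(torch.randn(2, 10), top_k=5, top_p=0.9)  # shape (2, 1), like sample_tokens()
```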
protxlstm/models/__init__.py ADDED
File without changes
protxlstm/models/llama.py ADDED
@@ -0,0 +1,342 @@
1
+ import json
2
+ import math
3
+ import os
4
+ from collections import namedtuple
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from transformers import PretrainedConfig
11
+
12
+ from protxlstm.xlstm.components.rotary_position import compute_freqs_cis
13
+
14
+ # Note: generation capabilities are not implemented for the transformer
15
+
16
+ class TransformerConfig(PretrainedConfig):
17
+
18
+ model_type = "llama"
19
+
20
+ def __init__(
21
+ self,
22
+ d_model,
23
+ n_layer,
24
+ n_heads,
25
+ n_kv_heads,
26
+ bidirectional,
27
+ vocab_size,
28
+ hidden_dim,
29
+ multiple_of, # MLP hidden layer size will be multiple of
30
+ norm_eps,
31
+ max_length,
32
+ dropout,
33
+ max_position_embeddings,
34
+ rope_base_frequency,
35
+ **kwargs
36
+ ):
37
+ super().__init__(**kwargs)
38
+
39
+ # default hyperparameters for the Llama 7B model
40
+ self.dim = d_model
41
+ self.n_layers = n_layer
42
+ self.n_heads = n_heads
43
+ self.n_kv_heads = n_kv_heads
44
+ self.causal_attention = not bidirectional
45
+ self.vocab_size = vocab_size
46
+ self.hidden_dim = hidden_dim
47
+ self.multiple_of = multiple_of
48
+ self.norm_eps = norm_eps
49
+ self.max_seq_len = max_length
50
+ self.dropout = dropout
51
+ self.max_position_embeddings = max_position_embeddings
52
+ self.rope_base_frequency = rope_base_frequency
53
+
54
+ class RMSNorm_transformer(torch.nn.Module):
55
+ def __init__(self, dim: int, eps: float):
56
+ super().__init__()
57
+ self.eps = eps
58
+ self.weight = nn.Parameter(torch.ones(dim))
59
+
60
+ def _norm(self, x):
61
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
62
+
63
+ def forward(self, x):
64
+ output = self._norm(x.float()).type_as(x)
65
+ return output * self.weight
66
+
67
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
68
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
69
+ t = torch.arange(end, device=freqs.device) # type: ignore
70
+ freqs = torch.outer(t, freqs).float() # type: ignore
71
+ freqs_cos = torch.cos(freqs) # real part
72
+ freqs_sin = torch.sin(freqs) # imaginary part
73
+ return freqs_cos, freqs_sin
74
+
75
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
76
+ ndim = x.ndim
77
+ assert 0 <= 1 < ndim
78
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1])
79
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
80
+ return freqs_cis.view(shape)
81
+
82
+ def apply_rotary_emb(
83
+ xq: torch.Tensor,
84
+ xk: torch.Tensor,
85
+ freqs_cos: torch.Tensor,
86
+ freqs_sin: torch.Tensor
87
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
88
+
89
+ # reshape xq and xk to match the complex representation
90
+ xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
91
+ xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
92
+
93
+ # reshape freqs_cos and freqs_sin for broadcasting
94
+ freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
95
+ freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
96
+
97
+ # apply rotation using real numbers
98
+ xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
99
+ xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
100
+ xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
101
+ xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
102
+
103
+ # flatten last two dimensions
104
+ xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
105
+ xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
106
+
107
+ return xq_out.type_as(xq), xk_out.type_as(xk)
108
+
109
+ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
110
+ """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
111
+ bs, slen, n_kv_heads, head_dim = x.shape
112
+ if n_rep == 1:
113
+ return x
114
+ return (
115
+ x[:, :, :, None, :]
116
+ .expand(bs, slen, n_kv_heads, n_rep, head_dim)
117
+ .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
118
+ )
119
+
120
+ class Attention(nn.Module):
121
+ def __init__(self, args: TransformerConfig):
122
+ super().__init__()
123
+ self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
124
+ assert args.n_heads % self.n_kv_heads == 0
125
+ model_parallel_size = 1
126
+ self.n_local_heads = args.n_heads // model_parallel_size
127
+ self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
128
+ self.n_rep = self.n_local_heads // self.n_local_kv_heads
129
+ self.head_dim = args.dim // args.n_heads
130
+ self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
131
+ self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
132
+ self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
133
+ self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
134
+ self.attn_dropout = nn.Dropout(args.dropout)
135
+ self.resid_dropout = nn.Dropout(args.dropout)
136
+ self.dropout = args.dropout
137
+ self.causal_attention = args.causal_attention
138
+
139
+ # use flash attention or a manual implementation?
140
+ self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
141
+ if not self.flash and self.causal_attention:
142
+ print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
143
+ mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
144
+ mask = torch.triu(mask, diagonal=1)
145
+ self.register_buffer("mask", mask)
146
+
147
+ def forward(
148
+ self,
149
+ x: torch.Tensor,
150
+ freqs_cos: torch.Tensor,
151
+ freqs_sin: torch.Tensor,
152
+ ):
153
+ bsz, seqlen, _ = x.shape
154
+
155
+ # QKV
156
+ xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
157
+ xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
158
+ xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
159
+ xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
160
+
161
+ # RoPE relative positional embeddings
162
+ xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
163
+
164
+ # grouped multiquery attention: expand out keys and values
165
+ xk = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
166
+ xv = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
167
+
168
+ # make heads into a batch dimension
169
+ xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
170
+ xk = xk.transpose(1, 2)
171
+ xv = xv.transpose(1, 2)
172
+
173
+ # flash implementation
174
+ if self.flash:
175
+ output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=self.causal_attention)
176
+ else:
177
+ # manual implementation
178
+ scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
179
+ if self.causal_attention:
180
+ scores = scores + self.mask[:, :, :seqlen, :seqlen] # (bs, n_local_heads, seqlen, cache_len + seqlen)
181
+ scores = F.softmax(scores.float(), dim=-1).type_as(xq)
182
+ scores = self.attn_dropout(scores)
183
+ output = torch.matmul(scores, xv) # (bs, n_local_heads, seqlen, head_dim)
184
+
185
+ # restore time as batch dimension and concat heads
186
+ output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
187
+
188
+ # final projection into the residual stream
189
+ output = self.wo(output)
190
+ output = self.resid_dropout(output)
191
+ return output
192
+
193
+ class FeedForward(nn.Module):
194
+ def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
195
+ super().__init__()
196
+ if hidden_dim is None:
197
+ hidden_dim = 4 * dim
198
+ hidden_dim = int(2 * hidden_dim / 3)
199
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
200
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
201
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
202
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
203
+ self.dropout = nn.Dropout(dropout)
204
+
205
+ def forward(self, x):
206
+ return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
207
+
208
+ class TransformerBlock(nn.Module):
209
+ def __init__(self, layer_id: int, args: TransformerConfig):
210
+ super().__init__()
211
+ self.n_heads = args.n_heads
212
+ self.dim = args.dim
213
+ self.head_dim = args.dim // args.n_heads
214
+ self.attention = Attention(args)
215
+ self.feed_forward = FeedForward(
216
+ dim=args.dim,
217
+ hidden_dim=args.hidden_dim,
218
+ multiple_of=args.multiple_of,
219
+ dropout=args.dropout,
220
+ )
221
+ self.layer_id = layer_id
222
+ self.attention_norm = RMSNorm_transformer(args.dim, eps=args.norm_eps)
223
+ self.ffn_norm = RMSNorm_transformer(args.dim, eps=args.norm_eps)
224
+
225
+ def forward(self, x, freqs_cos, freqs_sin):
226
+ h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
227
+ out = h + self.feed_forward.forward(self.ffn_norm(h))
228
+ return out
229
+
230
+ class Transformer(nn.Module):
231
+
232
+ last_loss: Optional[torch.Tensor]
233
+
234
+ def __init__(self, params: TransformerConfig):
235
+ super().__init__()
236
+ self.params = params
237
+ self.vocab_size = params.vocab_size
238
+ self.n_layers = params.n_layers
239
+
240
+ self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
241
+ self.dropout = nn.Dropout(params.dropout)
242
+ self.layers = torch.nn.ModuleList()
243
+ for layer_id in range(params.n_layers):
244
+ self.layers.append(TransformerBlock(layer_id, params))
245
+ self.layer_head_dim = self.layers[0].head_dim
246
+
247
+ self.norm = RMSNorm_transformer(params.dim, eps=params.norm_eps)
248
+ self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
249
+
250
+ # share the unembedding parameters with the embedding parameters
251
+ self.tok_embeddings.weight = self.output.weight # https://paperswithcode.com/method/weight-tying
252
+
253
+ # some useful precompute for the RoPE relative positional embeddings
254
+ # freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
255
+ # self.register_buffer("freqs_cos", freqs_cos, persistent=False)
256
+ # self.register_buffer("freqs_sin", freqs_sin, persistent=False)
257
+
258
+ # init all weights
259
+ self.apply(self._init_weights)
260
+ # apply special scaled init to the residual projections, per GPT-2 paper
261
+ for pn, p in self.named_parameters():
262
+ if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
263
+ torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * params.n_layers))
264
+
265
+ # Initialize attribute for the loss of the last forward call. This will be set if the forward is called with a targets tensor.
266
+ self.last_loss = None
267
+
268
+ def _init_weights(self, module):
269
+ if isinstance(module, nn.Linear):
270
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
271
+ if module.bias is not None:
272
+ torch.nn.init.zeros_(module.bias)
273
+ elif isinstance(module, nn.Embedding):
274
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
275
+
276
+ def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
277
+ _bsz, seqlen = tokens.shape
278
+ h = self.tok_embeddings(tokens)
279
+ h = self.dropout(h)
280
+ # freqs_cos = self.freqs_cos[:seqlen]
281
+ # freqs_sin = self.freqs_sin[:seqlen]
282
+
283
+ if 'position_ids' in kwargs:
284
+ freqs_cos, freqs_sin = compute_freqs_cis(kwargs.pop("position_ids"), self.layer_head_dim, theta=self.params.rope_base_frequency)
285
+ else:
286
+ raise ValueError('Llama model only implemented with RoPEs')
287
+
288
+ freqs_cos = freqs_cos.squeeze()
289
+ freqs_sin = freqs_sin.squeeze()
290
+
291
+ for layer in self.layers:
292
+ h = layer(h, freqs_cos, freqs_sin)
293
+ h = self.norm(h)
294
+
295
+ if targets is not None:
296
+ logits = self.output(h)
297
+ self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
298
+ else:
299
+ logits = self.output(h)
300
+ self.last_loss = None
301
+
302
+ return logits
303
+
304
+ class TransformerLMHeadModel(nn.Module):
305
+
306
+ def __init__(
307
+ self,
308
+ config: TransformerConfig,
309
+ ) -> None:
310
+
311
+ super().__init__()
312
+
313
+ self.config = config
314
+
315
+ self.backbone = Transformer(config)
316
+
317
+ def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
318
+ """
319
+ num_last_tokens: if > 0, only return the logits for the last n tokens
320
+ """
321
+
322
+ lm_logits = self.backbone(input_ids, position_ids=position_ids)
323
+
324
+ CausalLMOutput = namedtuple("CausalLMOutput", ["loss", "logits"])
325
+ return CausalLMOutput(loss=None, logits=lm_logits)
326
+
327
+ def save_pretrained(self, save_directory):
328
+ """
329
+ Save the model and its configuration file to a directory.
330
+ """
331
+
332
+ # Ensure save_directory exists
333
+ os.makedirs(save_directory, exist_ok=True)
334
+
335
+ # Save the model's state_dict
336
+ model_path = os.path.join(save_directory, "pytorch_model.bin")
337
+ torch.save(self.state_dict(), model_path)
338
+
339
+ # Save the configuration of the model
340
+ config_path = os.path.join(save_directory, "config.json")
341
+ with open(config_path, "w") as f:
342
+ json.dump(self.config.to_dict(), f)
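A quick shape check of the RoPE helpers defined above; the snippet is illustrative and not part of the module.

```python
# Illustrative shape check of the RoPE helpers above (not part of the module).
import torch
from protxlstm.models.llama import precompute_freqs_cis, apply_rotary_emb, repeat_kv

bsz, seqlen, n_heads, head_dim = 2, 8, 4, 16
xq = torch.randn(bsz, seqlen, n_heads, head_dim)
xk = torch.randn(bsz, seqlen, n_heads, head_dim)

freqs_cos, freqs_sin = precompute_freqs_cis(head_dim, seqlen)    # each (seqlen, head_dim // 2)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)  # shapes unchanged

xk_rep = repeat_kv(xk_rot, n_rep=2)                              # (bsz, seqlen, n_heads * 2, head_dim)
print(xq_rot.shape, xk_rep.shape)
```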
protxlstm/models/mamba.py ADDED
@@ -0,0 +1,833 @@
1
+ # Original code from ProtMamba under Apache License 2.0.
2
+
3
+ import json
4
+ import os
5
+ from collections import namedtuple
6
+ from dataclasses import dataclass, field
7
+ from functools import partial
8
+
9
+ from mamba_ssm.models.config_mamba import MambaConfig
10
+ from mamba_ssm.modules.mamba_simple import Block, Mamba
11
+ from mamba_ssm.models.mixer_seq_simple import MixerModel, _init_weights
12
+ from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
13
+ from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.utils.checkpoint import checkpoint
17
+ from transformers import PretrainedConfig
18
+
19
+ from protxlstm.generation import GenerationMixinSafe
20
+
21
+ @dataclass
22
+ class MambaConfig(PretrainedConfig):
23
+ d_model: int = 2560
24
+ n_layer: int = 64
25
+ vocab_size: int = 50277
26
+ ssm_cfg: dict = field(default_factory=dict)
27
+ rms_norm: bool = True
28
+ residual_in_fp32: bool = True
29
+ fused_add_norm: bool = True
30
+ pad_vocab_size_multiple: int = 8
31
+ max_position_embeddings: int = 2048
32
+
33
+ def create_block(
34
+ d_model,
35
+ ssm_cfg=None,
36
+ norm_epsilon=1e-5,
37
+ rms_norm=False,
38
+ residual_in_fp32=False,
39
+ fused_add_norm=False,
40
+ layer_idx=None,
41
+ device=None,
42
+ dtype=None,
43
+ checkpoint_mixer=False,
44
+ ):
45
+ if ssm_cfg is None:
46
+ ssm_cfg = {}
47
+ factory_kwargs = {"device": device, "dtype": dtype}
48
+ mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
49
+ norm_cls = partial(
50
+ nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
51
+ )
52
+ block = Block(
53
+ d_model,
54
+ mixer_cls,
55
+ norm_cls=norm_cls,
56
+ fused_add_norm=fused_add_norm,
57
+ residual_in_fp32=residual_in_fp32,
58
+ )
59
+ block.layer_idx = layer_idx
60
+ if checkpoint_mixer:
61
+ block.mixer = CheckpointedModule(block.mixer)
62
+ return block
63
+
64
+ class CheckpointedModule(torch.nn.Module):
65
+ def __init__(self, layer):
66
+ super().__init__()
67
+ self.ckpt_layer = layer
68
+
69
+ def forward(self, x, *args, **kwargs):
70
+ return checkpoint(self.ckpt_layer, x, use_reentrant=False)
71
+
72
+ # def state_dict(self, **kwargs):
73
+ # # Get the state dict of the underlying layer
74
+ # layer_state_dict = self.ckpt_layer.state_dict(**kwargs)
75
+ # # Create a new state dict with the original keys
76
+ # state_dict = {k.replace('ckpt_layer.', ''): v for k, v in layer_state_dict.items()}
77
+ # return state_dict
78
+
79
+ class MixerModelSafe(MixerModel):
80
+ """
81
+ Overwrite the forward method to allow saving intermediate layers.
82
+ """
83
+
84
+ def forward(self, input_ids, inference_params=None, save_layer=[]):
85
+ hidden_states = self.embedding(input_ids)
86
+ residual = None
87
+ if len(save_layer) > 0:
88
+ hidden_states_dict = {}
89
+ for i, layer in enumerate(self.layers):
90
+ hidden_states, residual = layer(
91
+ hidden_states, residual, inference_params=inference_params
92
+ )
93
+ if i + 1 in save_layer:
94
+ hidden_states_dict[i + 1] = (
95
+ hidden_states.detach().cpu().to(torch.float).numpy()
96
+ )
97
+ if len(save_layer) > 0:
98
+ return hidden_states_dict
99
+
100
+ if not self.fused_add_norm:
101
+ residual = (
102
+ (hidden_states + residual) if residual is not None else hidden_states
103
+ )
104
+ hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
105
+ else:
106
+ # Set prenorm=False here since we don't need the residual
107
+ fused_add_norm_fn = (
108
+ rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
109
+ )
110
+ hidden_states = fused_add_norm_fn(
111
+ hidden_states,
112
+ self.norm_f.weight,
113
+ self.norm_f.bias,
114
+ eps=self.norm_f.eps,
115
+ residual=residual,
116
+ prenorm=False,
117
+ residual_in_fp32=self.residual_in_fp32,
118
+ )
119
+ return hidden_states
120
+
121
+ class MixerModelWithPosids(nn.Module):
122
+ r"""Mixer model for Mamba but we add positional encodings to the input embeddings."""
123
+
124
+ def __init__(
125
+ self,
126
+ d_model: int,
127
+ n_layer: int,
128
+ vocab_size: int,
129
+ max_position_embeddings: int,
130
+ ssm_cfg=None,
131
+ norm_epsilon: float = 1e-5,
132
+ rms_norm: bool = False,
133
+ initializer_cfg=None,
134
+ fused_add_norm=False,
135
+ residual_in_fp32=False,
136
+ device=None,
137
+ dtype=None,
138
+ checkpoint_mixer=False,
139
+ ) -> None:
140
+ factory_kwargs = {"device": device, "dtype": dtype}
141
+ super().__init__()
142
+ self.residual_in_fp32 = residual_in_fp32
143
+
144
+ self.embedding = nn.Embedding(vocab_size, d_model // 2, **factory_kwargs)
145
+ self.position_embedding = nn.Embedding(
146
+ max_position_embeddings, d_model - d_model // 2, **factory_kwargs
147
+ )
148
+
149
+ # We change the order of residual and layer norm:
150
+ # Instead of LN -> Attn / MLP -> Add, we do:
151
+ # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
152
+ # the main branch (output of MLP / Mixer). The model definition is unchanged.
153
+ # This is for performance reason: we can fuse add + layer_norm.
154
+ self.fused_add_norm = fused_add_norm
155
+ if self.fused_add_norm:
156
+ if layer_norm_fn is None or rms_norm_fn is None:
157
+ raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
158
+
159
+ self.layers = nn.ModuleList(
160
+ [
161
+ create_block(
162
+ d_model,
163
+ ssm_cfg=ssm_cfg,
164
+ norm_epsilon=norm_epsilon,
165
+ rms_norm=rms_norm,
166
+ residual_in_fp32=residual_in_fp32,
167
+ fused_add_norm=fused_add_norm,
168
+ layer_idx=i,
169
+ checkpoint_mixer=checkpoint_mixer,
170
+ **factory_kwargs,
171
+ )
172
+ for i in range(n_layer)
173
+ ]
174
+ )
175
+
176
+ self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
177
+ d_model, eps=norm_epsilon, **factory_kwargs
178
+ )
179
+
180
+ self.apply(
181
+ partial(
182
+ _init_weights,
183
+ n_layer=n_layer,
184
+ **(initializer_cfg if initializer_cfg is not None else {}),
185
+ )
186
+ )
187
+
188
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
189
+ return {
190
+ i: layer.allocate_inference_cache(
191
+ batch_size, max_seqlen, dtype=dtype, **kwargs
192
+ )
193
+ for i, layer in enumerate(self.layers)
194
+ }
195
+
196
+ def forward(self, input_ids, position_ids, inference_params=None, save_layer=[]):
197
+ hidden_states = torch.cat(
198
+ [
199
+ self.embedding(input_ids),
200
+ self.position_embedding(position_ids),
201
+ ],
202
+ -1,
203
+ )
204
+ residual = None
205
+ if len(save_layer) > 0:
206
+ hidden_states_dict = {}
207
+ for i, layer in enumerate(self.layers):
208
+ hidden_states, residual = layer(
209
+ hidden_states, residual, inference_params=inference_params
210
+ )
211
+ if i + 1 in save_layer:
212
+ hidden_states_dict[i + 1] = (
213
+ hidden_states.detach().cpu().to(torch.float).numpy()
214
+ )
215
+ if len(save_layer) > 0:
216
+ return hidden_states_dict
217
+
218
+ if not self.fused_add_norm:
219
+ residual = (
220
+ (hidden_states + residual) if residual is not None else hidden_states
221
+ )
222
+ hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
223
+ else:
224
+ fused_add_norm_fn = (
225
+ rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
226
+ )
227
+ hidden_states = fused_add_norm_fn(
228
+ hidden_states,
229
+ self.norm_f.weight,
230
+ self.norm_f.bias,
231
+ eps=self.norm_f.eps,
232
+ residual=residual,
233
+ prenorm=False,
234
+ residual_in_fp32=self.residual_in_fp32,
235
+ )
236
+ return hidden_states
237
+
238
+ class MixerModelWith2DPosids(nn.Module):
239
+ r"""Mixer model for Mamba but we add positional encodings to the input embeddings."""
240
+
241
+ def __init__(
242
+ self,
243
+ d_model: int,
244
+ n_layer: int,
245
+ vocab_size: int,
246
+ max_position_embeddings: int,
247
+ max_sequence_position_embeddings: int = 512,
248
+ ssm_cfg=None,
249
+ norm_epsilon: float = 1e-5,
250
+ rms_norm: bool = False,
251
+ initializer_cfg=None,
252
+ fused_add_norm=False,
253
+ residual_in_fp32=False,
254
+ device=None,
255
+ dtype=None,
256
+ checkpoint_mixer=False,
257
+ ) -> None:
258
+ factory_kwargs = {"device": device, "dtype": dtype}
259
+ super().__init__()
260
+ self.residual_in_fp32 = residual_in_fp32
261
+
262
+ self.embedding = nn.Embedding(
263
+ vocab_size, d_model - 2 * d_model // 4, **factory_kwargs
264
+ )
265
+ self.position_embedding = nn.Embedding(
266
+ max_position_embeddings, d_model // 4, **factory_kwargs
267
+ )
268
+ self.seq_position_embedding = nn.Embedding(
269
+ max_sequence_position_embeddings, d_model // 4, **factory_kwargs
270
+ )
271
+ self.d_embeddings = d_model - 2 * d_model // 4
272
+
273
+ # We change the order of residual and layer norm:
274
+ # Instead of LN -> Attn / MLP -> Add, we do:
275
+ # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
276
+ # the main branch (output of MLP / Mixer). The model definition is unchanged.
277
+ # This is for performance reason: we can fuse add + layer_norm.
278
+ self.fused_add_norm = fused_add_norm
279
+ if self.fused_add_norm:
280
+ if layer_norm_fn is None or rms_norm_fn is None:
281
+ raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
282
+
283
+ self.layers = nn.ModuleList(
284
+ [
285
+ create_block(
286
+ d_model,
287
+ ssm_cfg=ssm_cfg,
288
+ norm_epsilon=norm_epsilon,
289
+ rms_norm=rms_norm,
290
+ residual_in_fp32=residual_in_fp32,
291
+ fused_add_norm=fused_add_norm,
292
+ layer_idx=i,
293
+ checkpoint_mixer=checkpoint_mixer,
294
+ **factory_kwargs,
295
+ )
296
+ for i in range(n_layer)
297
+ ]
298
+ )
299
+
300
+ self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
301
+ d_model, eps=norm_epsilon, **factory_kwargs
302
+ )
303
+
304
+ self.apply(
305
+ partial(
306
+ _init_weights,
307
+ n_layer=n_layer,
308
+ **(initializer_cfg if initializer_cfg is not None else {}),
309
+ )
310
+ )
311
+
312
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
313
+ return {
314
+ i: layer.allocate_inference_cache(
315
+ batch_size, max_seqlen, dtype=dtype, **kwargs
316
+ )
317
+ for i, layer in enumerate(self.layers)
318
+ }
319
+
320
+ def forward(
321
+ self,
322
+ input_ids,
323
+ position_ids,
324
+ seq_position_ids,
325
+ inference_params=None,
326
+ save_layer=[],
327
+ ):
328
+ hidden_states = torch.cat(
329
+ [
330
+ self.embedding(input_ids),
331
+ self.position_embedding(position_ids),
332
+ self.seq_position_embedding(seq_position_ids),
333
+ ],
334
+ -1,
335
+ )
336
+ residual = None
337
+ if len(save_layer) > 0:
338
+ hidden_states_dict = {}
339
+ for i, layer in enumerate(self.layers):
340
+ hidden_states, residual = layer(
341
+ hidden_states, residual, inference_params=inference_params
342
+ )
343
+ if i + 1 in save_layer:
344
+ hidden_states_dict[i + 1] = (
345
+ hidden_states.detach().cpu().to(torch.float).numpy()
346
+ )
347
+ if len(save_layer) > 0:
348
+ return hidden_states_dict
349
+
350
+ if not self.fused_add_norm:
351
+ residual = (
352
+ (hidden_states + residual) if residual is not None else hidden_states
353
+ )
354
+ hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
355
+ else:
356
+ fused_add_norm_fn = (
357
+ rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
358
+ )
359
+ hidden_states = fused_add_norm_fn(
360
+ hidden_states,
361
+ self.norm_f.weight,
362
+ self.norm_f.bias,
363
+ eps=self.norm_f.eps,
364
+ residual=residual,
365
+ prenorm=False,
366
+ residual_in_fp32=self.residual_in_fp32,
367
+ )
368
+ return hidden_states
369
+
370
+ class MambaLMHeadModelSafe(nn.Module, GenerationMixinSafe):
371
+
372
+ def __init__(
373
+ self,
374
+ config: MambaConfig,
375
+ initializer_cfg=None,
376
+ device=None,
377
+ dtype=None,
378
+ checkpoint_mixer=False,
379
+ ) -> None:
380
+ self.config = config
381
+ d_model = config.d_model
382
+ n_layer = config.n_layer
383
+ vocab_size = config.vocab_size
384
+ ssm_cfg = config.ssm_cfg
385
+ rms_norm = config.rms_norm
386
+ residual_in_fp32 = config.residual_in_fp32
387
+ fused_add_norm = config.fused_add_norm
388
+ pad_vocab_size_multiple = config.pad_vocab_size_multiple
389
+ factory_kwargs = {"device": device, "dtype": dtype}
390
+ if checkpoint_mixer:
391
+ raise NotImplementedError(
392
+ "Checkpointing is not yet supported for MambaLMHeadModelSafe"
393
+ )
394
+
395
+ super().__init__()
396
+ if vocab_size % pad_vocab_size_multiple != 0:
397
+ vocab_size += pad_vocab_size_multiple - (
398
+ vocab_size % pad_vocab_size_multiple
399
+ )
400
+ self.backbone = MixerModelSafe(
401
+ d_model=d_model,
402
+ n_layer=n_layer,
403
+ vocab_size=vocab_size,
404
+ ssm_cfg=ssm_cfg,
405
+ rms_norm=rms_norm,
406
+ initializer_cfg=initializer_cfg,
407
+ fused_add_norm=fused_add_norm,
408
+ residual_in_fp32=residual_in_fp32,
409
+ **factory_kwargs,
410
+ )
411
+ self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
412
+
413
+ # Initialize weights and apply final processing
414
+ self.apply(
415
+ partial(
416
+ _init_weights,
417
+ n_layer=n_layer,
418
+ **(initializer_cfg if initializer_cfg is not None else {}),
419
+ )
420
+ )
421
+ self.tie_weights()
422
+
423
+ def tie_weights(self):
424
+ self.lm_head.weight = self.backbone.embedding.weight
425
+
426
+ def clip_grad_norm_(self, max_norm, norm_type=2.0):
427
+ r"""Clip the norm of the gradients for the model.
428
+ Args:
429
+ max_norm (float or int): The maximum norm of the gradients.
430
+ The gradients are modified in-place.
431
+ norm_type (float or int): The type of p-norm to use. Can be 'inf' for the infinity norm.
432
+ Returns:
433
+ Total norm of the parameters (viewed as a single vector).
434
+ """
435
+ return torch.nn.utils.clip_grad_value_(self.parameters(), max_norm)
436
+
437
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
438
+ return self.backbone.allocate_inference_cache(
439
+ batch_size, max_seqlen, dtype=dtype, **kwargs
440
+ )
441
+
442
+ def forward(
443
+ self,
444
+ input_ids,
445
+ position_ids=None,
446
+ inference_params=None,
447
+ num_last_tokens=0,
448
+ save_layer=[],
449
+ *args,
450
+ **kwargs,
451
+ ):
452
+ """
453
+ "position_ids" is just to be compatible with Transformer generation. We don't use it.
454
+ num_last_tokens: if > 0, only return the logits for the last n tokens
455
+ """
456
+ return self.protected_forward(
457
+ input_ids, position_ids, inference_params, num_last_tokens, save_layer
458
+ )
459
+
460
+ def protected_forward(
461
+ self,
462
+ input_ids,
463
+ position_ids=None,
464
+ inference_params=None,
465
+ num_last_tokens=0,
466
+ save_layer=[],
467
+ ):
468
+ hidden_states = self.backbone(
469
+ input_ids, inference_params=inference_params, save_layer=save_layer
470
+ )
471
+ if len(save_layer) > 0:
472
+ return hidden_states
473
+ if num_last_tokens > 0:
474
+ hidden_states = hidden_states[:, -num_last_tokens:]
475
+ lm_logits = self.lm_head(hidden_states)
476
+ CausalLMOutput = namedtuple("CausalLMOutput", ["loss", "logits"])
477
+ return CausalLMOutput(loss=None, logits=lm_logits)
478
+
479
+ @classmethod
480
+ def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
481
+ config_data = load_config_hf(pretrained_model_name)
482
+ config = MambaConfig(**config_data)
483
+ model = cls(config, device=device, dtype=dtype, **kwargs)
484
+ model.load_state_dict(
485
+ load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype),
486
+ strict=False,
487
+ )
488
+ return model
489
+
490
+ def save_pretrained(self, save_directory):
491
+ """
492
+ Minimal implementation of save_pretrained for MambaLMHeadModel.
493
+ Save the model and its configuration file to a directory.
494
+ """
495
+ # Ensure save_directory exists
496
+ os.makedirs(save_directory, exist_ok=True)
497
+
498
+ # Save the model's state_dict
499
+ model_path = os.path.join(save_directory, "pytorch_model.bin")
500
+ torch.save(self.state_dict(), model_path)
501
+
502
+ # Save the configuration of the model
503
+ config_path = os.path.join(save_directory, "config.json")
504
+ with open(config_path, "w") as f:
505
+ json.dump(self.config.__dict__, f)
506
+
507
+ class MambaLMHeadModelwithPosids(nn.Module, GenerationMixinSafe):
508
+
509
+ def __init__(
510
+ self,
511
+ config: MambaConfig,
512
+ initializer_cfg=None,
513
+ device=None,
514
+ dtype=None,
515
+ checkpoint_mixer=False,
516
+ ) -> None:
517
+ self.config = config
518
+ d_model = config.d_model
519
+ n_layer = config.n_layer
520
+ vocab_size = config.vocab_size
521
+ max_position_embeddings = config.max_position_embeddings
522
+ ssm_cfg = config.ssm_cfg
523
+ rms_norm = config.rms_norm
524
+ residual_in_fp32 = config.residual_in_fp32
525
+ fused_add_norm = config.fused_add_norm
526
+ pad_vocab_size_multiple = config.pad_vocab_size_multiple
527
+ factory_kwargs = {"device": device, "dtype": dtype}
528
+
529
+ super().__init__()
530
+ if vocab_size % pad_vocab_size_multiple != 0:
531
+ vocab_size += pad_vocab_size_multiple - (
532
+ vocab_size % pad_vocab_size_multiple
533
+ )
534
+ self.backbone = MixerModelWithPosids(
535
+ d_model=d_model,
536
+ n_layer=n_layer,
537
+ vocab_size=vocab_size,
538
+ max_position_embeddings=max_position_embeddings,
539
+ ssm_cfg=ssm_cfg,
540
+ rms_norm=rms_norm,
541
+ initializer_cfg=initializer_cfg,
542
+ fused_add_norm=fused_add_norm,
543
+ residual_in_fp32=residual_in_fp32,
544
+ checkpoint_mixer=checkpoint_mixer,
545
+ **factory_kwargs,
546
+ )
547
+ self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
548
+
549
+ # Initialize weights and apply final processing
550
+ self.apply(
551
+ partial(
552
+ _init_weights,
553
+ n_layer=n_layer,
554
+ **(initializer_cfg if initializer_cfg is not None else {}),
555
+ )
556
+ )
557
+ self.tie_weights()
558
+
559
+ def tie_weights(self):
560
+ self.lm_head.weight = self.backbone.embedding.weight
561
+
562
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
563
+ return self.backbone.allocate_inference_cache(
564
+ batch_size, max_seqlen, dtype=dtype, **kwargs
565
+ )
566
+
567
+ def forward(
568
+ self,
569
+ input_ids,
570
+ position_ids=None,
571
+ inference_params=None,
572
+ num_last_tokens=0,
573
+ save_layer=[],
574
+ *args,
575
+ **kwargs,
576
+ ):
577
+ """
578
+ "position_ids" is just to be compatible with Transformer generation. We don't use it.
579
+ num_last_tokens: if > 0, only return the logits for the last n tokens
580
+ """
581
+ return self.protected_forward(
582
+ input_ids, position_ids, inference_params, num_last_tokens, save_layer
583
+ )
584
+
585
+ def protected_forward(
586
+ self,
587
+ input_ids,
588
+ position_ids=None,
589
+ inference_params=None,
590
+ num_last_tokens=0,
591
+ save_layer=[],
592
+ ):
593
+ hidden_states = self.backbone(
594
+ input_ids,
595
+ position_ids=position_ids,
596
+ inference_params=inference_params,
597
+ save_layer=save_layer,
598
+ )
599
+ if len(save_layer) > 0:
600
+ return hidden_states
601
+ hidden_states = hidden_states[:, :, : self.config.d_model // 2]
602
+ if num_last_tokens > 0:
603
+ hidden_states = hidden_states[:, -num_last_tokens:]
604
+ lm_logits = self.lm_head(hidden_states)
605
+ CausalLMOutput = namedtuple("CausalLMOutput", ["loss", "logits"])
606
+ return CausalLMOutput(loss=None, logits=lm_logits)
607
+
608
+ @classmethod
609
+ def from_pretrained(
610
+ cls,
611
+ pretrained_model_name,
612
+ device=None,
613
+ dtype=None,
614
+ checkpoint_mixer=False,
615
+ **kwargs,
616
+ ):
617
+ config_data = load_config_hf(pretrained_model_name)
618
+ config = MambaConfig(**config_data)
619
+ model = cls(
620
+ config,
621
+ device=device,
622
+ dtype=dtype,
623
+ checkpoint_mixer=checkpoint_mixer,
624
+ **kwargs,
625
+ )
626
+ state_dict = load_state_dict_hf(
627
+ pretrained_model_name, device=device, dtype=dtype
628
+ )
629
+ if state_dict.keys() != model.state_dict().keys():
630
+ if checkpoint_mixer:
631
+ for key in model.state_dict().keys():
632
+ if "ckpt_layer" in key:
633
+ state_dict[key] = state_dict.pop(key.replace("ckpt_layer.", ""))
634
+ print(
635
+ "Using a model that was pretrained without gradient checkpointing and now want to use it. Changed the keys of the state_dict to match the model's keys."
636
+ )
637
+ else:
638
+ for key in list(state_dict.keys()):
639
+ if "ckpt_layer" in key:
640
+ state_dict[key.replace("ckpt_layer.", "")] = state_dict.pop(key)
641
+ print(
642
+ "Using a model that was pretrained with gradient checkpointing but now do not want to use it. Changed the keys of the state_dict to match the model's keys."
643
+ )
644
+ assert (
645
+ state_dict.keys() == model.state_dict().keys()
646
+ ), "The keys of the state_dict do not match the model's keys."
647
+ model.load_state_dict(state_dict)
648
+ return model
649
+
650
+ def save_pretrained(self, save_directory):
651
+ """
652
+ Minimal implementation of save_pretrained for MambaLMHeadModel.
653
+ Save the model and its configuration file to a directory.
654
+ """
655
+ # Ensure save_directory exists
656
+ os.makedirs(save_directory, exist_ok=True)
657
+
658
+ # Save the model's state_dict
659
+ model_path = os.path.join(save_directory, "pytorch_model.bin")
660
+ torch.save(self.state_dict(), model_path)
661
+
662
+ # Save the configuration of the model
663
+ config_path = os.path.join(save_directory, "config.json")
664
+ with open(config_path, "w") as f:
665
+ json.dump(self.config.__dict__, f)
666
+
667
+ class MambaLMHeadModelwith2DPosids(nn.Module, GenerationMixinSafe):
668
+
669
+ def __init__(
670
+ self,
671
+ config: MambaConfig,
672
+ initializer_cfg=None,
673
+ device=None,
674
+ dtype=None,
675
+ checkpoint_mixer=False,
676
+ ) -> None:
677
+ self.config = config
678
+ d_model = config.d_model
679
+ n_layer = config.n_layer
680
+ vocab_size = config.vocab_size
681
+ max_position_embeddings = config.max_position_embeddings
682
+ ssm_cfg = config.ssm_cfg
683
+ rms_norm = config.rms_norm
684
+ residual_in_fp32 = config.residual_in_fp32
685
+ fused_add_norm = config.fused_add_norm
686
+ pad_vocab_size_multiple = config.pad_vocab_size_multiple
687
+ factory_kwargs = {"device": device, "dtype": dtype}
688
+
689
+ super().__init__()
690
+ if vocab_size % pad_vocab_size_multiple != 0:
691
+ vocab_size += pad_vocab_size_multiple - (
692
+ vocab_size % pad_vocab_size_multiple
693
+ )
694
+ self.backbone = MixerModelWith2DPosids(
695
+ d_model=d_model,
696
+ n_layer=n_layer,
697
+ vocab_size=vocab_size,
698
+ max_position_embeddings=max_position_embeddings,
699
+ ssm_cfg=ssm_cfg,
700
+ rms_norm=rms_norm,
701
+ initializer_cfg=initializer_cfg,
702
+ fused_add_norm=fused_add_norm,
703
+ residual_in_fp32=residual_in_fp32,
704
+ checkpoint_mixer=checkpoint_mixer,
705
+ **factory_kwargs,
706
+ )
707
+ self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
708
+
709
+ # Initialize weights and apply final processing
710
+ self.apply(
711
+ partial(
712
+ _init_weights,
713
+ n_layer=n_layer,
714
+ **(initializer_cfg if initializer_cfg is not None else {}),
715
+ )
716
+ )
717
+ self.tie_weights()
718
+
719
+ def tie_weights(self):
720
+ self.lm_head.weight = self.backbone.embedding.weight
721
+
722
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
723
+ return self.backbone.allocate_inference_cache(
724
+ batch_size, max_seqlen, dtype=dtype, **kwargs
725
+ )
726
+
727
+ def forward(
728
+ self,
729
+ input_ids,
730
+ position_ids=None,
731
+ seq_position_ids=None,
732
+ inference_params=None,
733
+ num_last_tokens=0,
734
+ save_layer=[],
735
+ *args,
736
+ **kwargs,
737
+ ):
738
+ """
739
+ "position_ids" is just to be compatible with Transformer generation. We don't use it.
740
+ num_last_tokens: if > 0, only return the logits for the last n tokens
741
+ """
742
+ return self.protected_forward(
743
+ input_ids,
744
+ position_ids,
745
+ seq_position_ids,
746
+ inference_params,
747
+ num_last_tokens,
748
+ save_layer,
749
+ )
750
+
751
+ def protected_forward(
752
+ self,
753
+ input_ids,
754
+ position_ids=None,
755
+ seq_position_ids=None,
756
+ inference_params=None,
757
+ num_last_tokens=0,
758
+ save_layer=[],
759
+ ):
760
+ hidden_states = self.backbone(
761
+ input_ids,
762
+ position_ids=position_ids,
763
+ seq_position_ids=seq_position_ids,
764
+ inference_params=inference_params,
765
+ save_layer=save_layer,
766
+ )
767
+ if len(save_layer) > 0:
768
+ return hidden_states
769
+ hidden_states = hidden_states[:, :, : self.backbone.d_embeddings]
770
+ if num_last_tokens > 0:
771
+ hidden_states = hidden_states[:, -num_last_tokens:]
772
+ lm_logits = self.lm_head(hidden_states)
773
+ CausalLMOutput = namedtuple("CausalLMOutput", ["loss", "logits"])
774
+ return CausalLMOutput(loss=None, logits=lm_logits)
775
+
776
+ @classmethod
777
+ def from_pretrained(
778
+ cls,
779
+ pretrained_model_name,
780
+ device=None,
781
+ dtype=None,
782
+ checkpoint_mixer=False,
783
+ **kwargs,
784
+ ):
785
+ config_data = load_config_hf(pretrained_model_name)
786
+ config = MambaConfig(**config_data)
787
+ model = cls(
788
+ config,
789
+ device=device,
790
+ dtype=dtype,
791
+ checkpoint_mixer=checkpoint_mixer,
792
+ **kwargs,
793
+ )
794
+ state_dict = load_state_dict_hf(
795
+ pretrained_model_name, device=device, dtype=dtype
796
+ )
797
+ if state_dict.keys() != model.state_dict().keys():
798
+ if checkpoint_mixer:
799
+ for key in model.state_dict().keys():
800
+ if "ckpt_layer" in key:
801
+ state_dict[key] = state_dict.pop(key.replace("ckpt_layer.", ""))
802
+ print(
803
+ "Using a model that was pretrained without gradient checkpointing and now want to use it. Changed the keys of the state_dict to match the model's keys."
804
+ )
805
+ else:
806
+ for key in list(state_dict.keys()):
807
+ if "ckpt_layer" in key:
808
+ state_dict[key.replace("ckpt_layer.", "")] = state_dict.pop(key)
809
+ print(
810
+ "Using a model that was pretrained with gradient checkpointing but now do not want to use it. Changed the keys of the state_dict to match the model's keys."
811
+ )
812
+ assert (
813
+ state_dict.keys() == model.state_dict().keys()
814
+ ), "The keys of the state_dict do not match the model's keys."
815
+ model.load_state_dict(state_dict)
816
+ return model
817
+
818
+ def save_pretrained(self, save_directory):
819
+ """
820
+ Minimal implementation of save_pretrained for MambaLMHeadModel.
821
+ Save the model and its configuration file to a directory.
822
+ """
823
+ # Ensure save_directory exists
824
+ os.makedirs(save_directory, exist_ok=True)
825
+
826
+ # Save the model's state_dict
827
+ model_path = os.path.join(save_directory, "pytorch_model.bin")
828
+ torch.save(self.state_dict(), model_path)
829
+
830
+ # Save the configuration of the model
831
+ config_path = os.path.join(save_directory, "config.json")
832
+ with open(config_path, "w") as f:
833
+ json.dump(self.config.__dict__, f)
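A hypothetical usage sketch of the position-aware Mamba head defined above; the checkpoint path is a placeholder, and running it requires `mamba_ssm` with its CUDA kernels.

```python
# Hypothetical usage sketch; the checkpoint path is a placeholder and the call
# requires mamba_ssm with CUDA support.
import torch
from protxlstm.models.mamba import MambaLMHeadModelwithPosids

model = MambaLMHeadModelwithPosids.from_pretrained(
    "path/to/protmamba_checkpoint",   # directory holding config.json and pytorch_model.bin
    device="cuda",
    dtype=torch.bfloat16,
    checkpoint_mixer=False,           # True re-wraps the mixers for gradient checkpointing
)

input_ids = torch.randint(0, 32, (1, 16), device="cuda")
position_ids = torch.arange(16, device="cuda").unsqueeze(0)   # per-residue positions
out = model(input_ids, position_ids=position_ids)             # out.logits: (1, 16, vocab_size)

model.save_pretrained("my_checkpoint")                         # writes config.json + pytorch_model.bin
```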
protxlstm/models/xlstm.py ADDED
@@ -0,0 +1,180 @@
1
+ __all__ = [
2
+ "xLSTMConfig",
3
+ "xLSTMLMHeadModel",
4
+ ]
5
+
6
+ import json
7
+ import os
8
+ from collections import namedtuple
9
+ from dataclasses import asdict
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from dacite import Config as DaciteConfig, from_dict
14
+ from omegaconf import OmegaConf
15
+ from transformers import PretrainedConfig
16
+
17
+ from protxlstm.generation import GenerationMixinSafe
18
+ from protxlstm.utils import load_config_hf, load_state_dict_hf
19
+ from protxlstm.xlstm.xlstm_lm_model import xLSTMLMModel, xLSTMLMModelConfig
20
+
21
+
22
+ class xLSTMConfig(PretrainedConfig):
23
+
24
+ def __init__(self):
25
+ self.config_dataclass = xLSTMLMModelConfig()
26
+
27
+ def init_from_dict(self, config: dict):
28
+ config = OmegaConf.create(config)
29
+ self.config_dataclass = from_dict(
30
+ data_class=xLSTMLMModelConfig,
31
+ data=OmegaConf.to_container(config),
32
+ config=DaciteConfig(strict=True),
33
+ )
34
+ return self
35
+
36
+ def to_dict(self):
37
+ return asdict(self.config_dataclass)
38
+
39
+
40
+ class xLSTMLMHeadModel(nn.Module, GenerationMixinSafe):
41
+
42
+ def __init__(self, config: xLSTMConfig) -> None:
43
+ super().__init__()
44
+
45
+ self.config = config
46
+ self.backbone = xLSTMLMModel(self.config.config_dataclass)
47
+ self.backbone.reset_parameters()
48
+
49
+ self.setup()
50
+
51
+
52
+ def setup(self):
53
+
54
+ if 'LOCAL_RANK' in os.environ:
55
+ current_device = int(os.environ['LOCAL_RANK'])
56
+ else:
57
+ if 'SLURM_LOCALID' in os.environ:
58
+ current_device = int(os.environ['SLURM_LOCALID'])
59
+ else:
60
+ current_device = 0
61
+
62
+ #torch.cuda.set_device(f'cuda:{current_device}')
63
+
64
+ #self.backbone = self.backbone.to("cuda")
65
+
66
+
67
+ def forward(
68
+ self,
69
+ input_ids,
70
+ state=None,
71
+ position_ids=None,
72
+ seq_position_ids=None,
73
+ inference_params=None,
74
+ num_last_tokens=0,
75
+ save_layer=[],
76
+ **kwargs,
77
+ ):
78
+
79
+ if self.config.config_dataclass.mlstm_block.mlstm.return_last_state:
80
+ lm_logits, state = self.backbone(input_ids, position_ids=position_ids, seq_position_ids=seq_position_ids, state=state)
81
+ CausalLMOutput = namedtuple("CausalLMOutput", ["loss", "logits", "state"])
82
+ return CausalLMOutput(loss=None, logits=lm_logits, state=state)
83
+ else:
84
+ lm_logits = self.backbone(input_ids, position_ids=position_ids, seq_position_ids=seq_position_ids, state=state)
85
+ CausalLMOutput = namedtuple("CausalLMOutput", ["loss", "logits"])
86
+ return CausalLMOutput(loss=None, logits=lm_logits)
87
+
88
+ def step(
89
+ self,
90
+ input_ids,
91
+ state=None,
92
+ position_ids=None,
93
+ seq_position_ids=None,
94
+ inference_params=None,
95
+ num_last_tokens=0,
96
+ save_layer=[],
97
+ **kwargs,
98
+ ):
99
+
100
+ lm_logits, state = self.backbone.step(
101
+ input_ids, state=state, position_ids=position_ids, seq_position_ids=seq_position_ids
102
+ )
103
+
104
+ return lm_logits, state
105
+
106
+
107
+ @classmethod
108
+ def from_pretrained(
109
+ cls,
110
+ pretrained_model_name,
111
+ device=None,
112
+ dtype=None,
113
+ mlstm_backend=None,
114
+ mlstm_chunksize=None,
115
+ checkpoint_blocks=None,
116
+ rope_base_frequency=None,
117
+ mlstm_return_last_state=None,
118
+ ):
119
+ # Load the checkpoint config
120
+ config_dict = load_config_hf(pretrained_model_name)
121
+
122
+ # update rope base frequency
123
+ if rope_base_frequency is not None and config_dict.get("rope_base_frequency", None) != rope_base_frequency:
124
+ config_dict["rope_base_frequency"] = rope_base_frequency
125
+
126
+ # update mlstm backend
127
+ if mlstm_backend is not None and config_dict["mlstm_block"]["mlstm"].get("backend", None) != mlstm_backend:
128
+ assert mlstm_backend in ["chunkwise", "chunkwise_variable", "parallel"], "invalid mlstm backend."
129
+ config_dict["mlstm_block"]["mlstm"]["backend"] = mlstm_backend
130
+
131
+ # update mlstm chunksize
132
+ if mlstm_chunksize is not None and config_dict["mlstm_block"]["mlstm"].get("chunk_size", None) != mlstm_chunksize:
133
+ config_dict["mlstm_block"]["mlstm"]["chunk_size"] = mlstm_chunksize
134
+
135
+ # update activation checkpointing
136
+ if checkpoint_blocks is not None:
137
+ config_dict["checkpoint_blocks"] = checkpoint_blocks
138
+
139
+ if mlstm_return_last_state is not None:
140
+ config_dict["mlstm_block"]["mlstm"]["return_last_state"] = mlstm_return_last_state
141
+
142
+ if "slstm_block" in config_dict:
143
+ config_dict.pop("slstm_block")
144
+
145
+ if "slstm_at" in config_dict:
146
+ config_dict.pop("slstm_at")
147
+
148
+ config = xLSTMConfig().init_from_dict(config_dict)
149
+
150
+ model = cls(config)
151
+
152
+ state_dict = load_state_dict_hf(
153
+ pretrained_model_name, device=device, dtype=dtype
154
+ )
155
+ assert (
156
+ state_dict.keys() == model.state_dict().keys()
157
+ ), "The keys of the state_dict do not match the model's keys."
158
+
159
+ model.load_state_dict(state_dict)
160
+
161
+ return model
162
+
163
+ def save_pretrained(self, save_directory):
164
+ """
165
+ Save the model and its configuration file to a directory.
166
+ """
167
+
168
+ # Ensure save_directory exists
169
+ os.makedirs(save_directory, exist_ok=True)
170
+
171
+ # Save the model's state_dict
172
+ model_path = os.path.join(save_directory, "pytorch_model.bin")
173
+ torch.save(self.state_dict(), model_path)
174
+
175
+ # Save the configuration of the model
176
+ config_path = os.path.join(save_directory, "config.json")
177
+ with open(config_path, "w") as f:
178
+ json.dump(self.config.to_dict(), f)
179
+
180
+
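A hypothetical loading sketch for the xLSTM head above; the checkpoint path and override values are illustrative only.

```python
# Hypothetical loading sketch; the checkpoint path and overrides are illustrative.
import torch
from protxlstm.models.xlstm import xLSTMLMHeadModel

model = xLSTMLMHeadModel.from_pretrained(
    "path/to/xlstm_checkpoint",           # any directory with config.json + pytorch_model.bin
    mlstm_backend="chunkwise_variable",   # override the backend stored in the checkpoint
    mlstm_chunksize=1024,
    mlstm_return_last_state=True,
)

input_ids = torch.randint(0, 32, (1, 64))
position_ids = torch.arange(64).unsqueeze(0)
out = model(input_ids, position_ids=position_ids)
# out.logits has shape (1, 64, vocab_size); out.state carries the recurrent
# mLSTM state because return_last_state was enabled above.
```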
protxlstm/plot_utils.py ADDED
@@ -0,0 +1,26 @@
1
+
2
+ cd = {  # color palette keyed by model type
3
+ "xLSTM": "#3073AD",
4
+ "Transformers": "#4B9D7A",
5
+ "Mamba": "#DF8953",
6
+ "S4": "#D275AB",
7
+ "Hyena": "#E86A61",
8
+ }
9
+
10
+ def setup_matplotlib():
11
+ import matplotlib.pyplot as plt
12
+ from tueplots import bundles, axes
13
+ bundles.icml2022()
14
+ plt.rcParams.update(bundles.icml2022())
15
+ plt.rcParams.update(axes.lines(base_width=0.5))
16
+ plt.rcParams["text.usetex"] = False
17
+ plt.rcParams['font.family'] = "sans-serif"
18
+ plt.rcParams['font.serif'] = 'Arial'
19
+ plt.rcParams['legend.edgecolor'] = 'grey'
20
+ plt.rcParams['legend.framealpha'] = 0.7
21
+ plt.rcParams['lines.linewidth'] = 1.2
22
+ plt.rcParams['axes.grid'] = True
23
+ plt.rcParams['axes.grid.axis'] = 'both'
24
+ plt.rcParams['grid.alpha'] = 0.2
25
+ plt.rcParams['axes.grid'] = True
26
+ plt.rcParams['axes.prop_cycle'] = plt.cycler(color=cd.values())
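An illustrative use of the plotting helpers above (assumes `matplotlib` and `tueplots` are installed); the plotted values are placeholders.

```python
# Illustrative use of the plotting helpers above (assumes tueplots is installed).
import matplotlib.pyplot as plt
from protxlstm.plot_utils import cd, setup_matplotlib

setup_matplotlib()  # apply the ICML-style rcParams defined above
plt.plot([1, 2, 3], [0.30, 0.25, 0.22], label="xLSTM", color=cd["xLSTM"])
plt.plot([1, 2, 3], [0.32, 0.28, 0.26], label="Mamba", color=cd["Mamba"])
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.savefig("loss_curves.png")
```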
protxlstm/train.py ADDED
@@ -0,0 +1,338 @@
1
+ # Original code from ProtMamba under Apache License 2.0.
2
+ #
3
+ # Modifications made by Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
4
+ # - Extended to training of xlstm and transformer-based models
5
+ # - Predefined splits instead of on-the-fly creation
6
+ # - Option to overwrite config parameters from the command line
7
+ # - wandb logging
8
+
9
+ import argparse
10
+ import os
11
+
12
+ import torch
13
+ from omegaconf import OmegaConf
14
+ from transformers import TrainingArguments
15
+
16
+ from protxlstm.dataloaders import ProteinMemmapDataset, ProteinDataCollator
17
+ from protxlstm.models.xlstm import xLSTMConfig, xLSTMLMHeadModel
18
+ from protxlstm.models.llama import TransformerConfig, TransformerLMHeadModel
19
+ from protxlstm.trainer import ProtTrainer, EarlyStoppingCallback, get_last_checkpoint
20
+ from protxlstm.utils import (
21
+ AA_TO_ID,
22
+ compute_metrics,
23
+ is_zero_rank,
24
+ parse_override_args,
25
+ print_number_of_parameters,
26
+ print_zero_rank,
27
+ set_optimizer_and_scheduler,
28
+ setup_wandb,
29
+ load_model,
30
+ )
31
+
32
+ def run(config):
33
+ """
34
+ Run training loop.
35
+
36
+ Args:
37
+ config (dict): dictionary with the configuration parameters.
38
+ """
39
+
40
+ if config.model_type == 'llama':
41
+ pe_kwargs = {
42
+ 'max_position_embeddings' : config["model"]["max_position_embeddings"],
43
+ 'add_position_ids' : '1d',
44
+ }
45
+ elif config.model_type == 'mamba':
46
+ from protxlstm.models.mamba import MambaConfig, MambaLMHeadModelSafe, MambaLMHeadModelwithPosids, MambaLMHeadModelwith2DPosids
47
+ pe_kwargs = {
48
+ 'max_position_embeddings' : config["model"]["max_position_embeddings"],
49
+ 'max_seq_position_embeddings' : config["model"]["max_seq_position_embeddings"],
50
+ 'add_position_ids' : config["model"]["add_position_ids"]
51
+ }
52
+ else:
53
+ position_embeddings = config["model"]["position_embeddings"]
54
+ assert position_embeddings in ["none", "abs_1d", "abs_2d", "rot_1d", "rot_2d"]
55
+ if position_embeddings != "none":
56
+ position_embeddings = position_embeddings.split("_")[-1]
57
+ pe_kwargs = {
58
+ 'max_position_embeddings' : config["model"]["max_position_embeddings"],
59
+ 'max_seq_position_embeddings' : config["model"]["max_seq_position_embeddings"],
60
+ 'add_position_ids' : position_embeddings
61
+ }
62
+
63
+ # Setup WandB
64
+ wandb_run_name = setup_wandb(config)
65
+
66
+ # Load datasets
67
+ dataset_params = {
68
+ "msa_memmap_path": config["msa_memmap_path"],
69
+ "msa_memmap_meta_path": config["msa_memmap_meta_path"],
70
+ "sample": config["sample_sequences"],
71
+ "max_msa_len": config["max_msa_len"],
72
+ "reverse": False,
73
+ "seed": config["seed_sequence_sampling"],
74
+ "troubleshoot": False,
75
+ "fim_strategy": config["fim_strategy"],
76
+ "always_mask": config["always_mask"],
77
+ **pe_kwargs,
78
+ }
79
+ train_dataset = ProteinMemmapDataset(subset_path=config["train_set"], **dataset_params)
80
+ valid_dataset = ProteinMemmapDataset(subset_path=config["valid_set"], **dataset_params)
81
+ train_eval_dataset = ProteinMemmapDataset(subset_path=config["train_eval_set"], **dataset_params)
82
+
83
+ print(f'Train set size: {len(train_dataset)} Train eval set size: {len(train_eval_dataset)} Valid set size: {len(valid_dataset)}')
84
+
85
+ assert (
86
+ len(AA_TO_ID) == config["model"]["vocab_size"]
87
+ ), f"Vocab size in the config file does not match the one in the code. I should be {len(AA_TO_ID)}"
88
+
89
+ # Create data collator for batched training
90
+ data_collator = ProteinDataCollator(max_sequence_length=config["max_msa_len"])
91
+
92
+ # Check datatypes
93
+ if config["dtype"] == "float32":
94
+ dtype = torch.float32
95
+ elif config["dtype"] == "bfloat16":
96
+ dtype = torch.bfloat16
97
+ else:
98
+ raise ValueError("dtype must be either float32 or bfloat16")
99
+
100
+ # Initialize model
101
+ if config.model_type == 'xlstm':
102
+
103
+ # Load model for finetuning
104
+ if config.finetune_model_path:
105
+ # These fields are updated in the config loaded from the checkpoint
106
+ config_update_kwargs = {
107
+ "mlstm_backend": config["model"]["mlstm_block"]["mlstm"]["backend"],
108
+ "mlstm_chunksize": config["model"]["mlstm_block"]["mlstm"]["chunk_size"],
109
+ "checkpoint_blocks": config["model"]["checkpoint_blocks"],
110
+ "rope_base_frequency": config["model"]["rope_base_frequency"]
111
+ }
112
+ model = load_model(
113
+ config.finetune_model_path,
114
+ model_class=xLSTMLMHeadModel,
115
+ device="cuda",
116
+ dtype=dtype,
117
+ **config_update_kwargs
118
+ )
119
+ else:
120
+ # Create new model
121
+ xlstm_config = xLSTMConfig().init_from_dict(config["model"])
122
+ model = xLSTMLMHeadModel(xlstm_config)
123
+
124
+ elif config.model_type == 'mamba':
125
+
126
+ _mamba_model = {
127
+ "none": MambaLMHeadModelSafe,
128
+ "1d": MambaLMHeadModelwithPosids,
129
+ "2d": MambaLMHeadModelwith2DPosids,
130
+ }
131
+ Mamba = _mamba_model[config['model']["add_position_ids"]]
132
+
133
+ # Load model for finetuning
134
+ if config.finetune_model_path:
135
+ model = load_model(
136
+ config.finetune_model_path,
137
+ model_class=Mamba,
138
+ device="cuda",
139
+ dtype=dtype,
140
+ checkpoint_mixer=config["checkpoint_mixer"],
141
+ )
142
+ else:
143
+ # Create new model
144
+ mamba_config = MambaConfig(d_model=config['model']["d_model"],
145
+ n_layer=config['model']["n_layer"],
146
+ vocab_size=config['model']["vocab_size"],
147
+ residual_in_fp32=config['model']["residual_in_fp32"])
148
+ model = Mamba(mamba_config, dtype=dtype, checkpoint_mixer=config['model']["checkpoint_mixer"])
149
+
150
+ elif config.model_type == 'llama':
151
+
152
+ llama_config = TransformerConfig(
153
+ d_model=config["model"]["d_model"],
154
+ n_layer=config["model"]["n_layer"],
155
+ n_heads=config["model"]["n_heads"],
156
+ n_kv_heads=config["model"]["n_kv_heads"],
157
+ bidirectional=config["model"]["bidirectional"],
158
+ hidden_dim=config["model"]["hidden_dim"],
159
+ multiple_of=config["model"]["multiple_of"],
160
+ norm_eps=config["model"]["norm_eps"],
161
+ max_length=config["model"]["max_length"],
162
+ vocab_size=config["model"]["vocab_size"],
163
+ dropout=config["model"]["dropout"],
164
+ max_position_embeddings=config["model"]["max_position_embeddings"],
165
+ rope_base_frequency=config["model"]["rope_base_frequency"],
166
+
167
+ )
168
+
169
+ model = TransformerLMHeadModel(llama_config)
170
+
171
+ else:
172
+ raise ValueError(f"Unsupported model_type: {config.model_type}. Expected 'xlstm', 'mamba', or 'llama'.")
173
+
174
+
175
+ # TODO: Improve what we want to print
176
+ if is_zero_rank():
177
+ print_number_of_parameters(model)
178
+ print_zero_rank(f"dtype: {config['dtype']}")
179
+ print_zero_rank(f"Epochs: {config['num_epochs']}")
180
+ print_zero_rank(f"Batch size per GPU: {config['batch_size']}")
181
+ print_zero_rank(f"Gradient accumulation steps: {config['gradient_accumulation_steps']}")
182
+ eff_batch_size = config["batch_size"] * config["gradient_accumulation_steps"]
183
+ nr_gpus = torch.cuda.device_count()
184
+ print_zero_rank(f"GPUS: {nr_gpus}")
185
+ eff_batch_size *= nr_gpus
186
+ print_zero_rank(f"Effective batch size: {eff_batch_size}")
187
+ print_zero_rank(
188
+ f"Steps per training epoch: {len(train_dataset) // config['batch_size']}, eff. steps: {len(train_dataset) // eff_batch_size}"
189
+ )
190
+ print_zero_rank(f"Steps per evaluation epoch: {len(valid_dataset) // config['batch_size']}")
191
+ print_zero_rank(f"Max MSA length: {config['max_msa_len']}")
192
+ ev_epochs = round(
193
+ config["eval_steps"] * config["batch_size"] / len(train_dataset), 3
194
+ )
195
+ print_zero_rank(
196
+ f"Evaluation every {config['eval_steps']} steps, i.e. {ev_epochs} epochs. Effectively every {config['eval_steps']*config['gradient_accumulation_steps']} steps, i.e. {ev_epochs*config['gradient_accumulation_steps']} epochs."
197
+ )
198
+ if config.model_type == 'xlstm' and config["model"]["checkpoint_blocks"]:
199
+ print_zero_rank("Using gradient checkpointing")
200
+ if config["compute_only_fim_loss"]:
201
+ print_zero_rank("Computing only FIM loss for training")
202
+
203
+ # Training callbacks
204
+ es_callback = EarlyStoppingCallback(
205
+ train_path=config["output_dir"] + '/' + wandb_run_name, config=config
206
+ )
207
+ callbacks = [es_callback]
208
+
209
+ # Optimizer and Schedulers
210
+ optimizer, scheduler = set_optimizer_and_scheduler(
211
+ config,
212
+ len(train_dataset),
213
+ model.parameters()
214
+ )
215
+
216
+ # Find checkpoint if available
217
+ last_checkpoint = None
218
+ if config.finetune_model_path is None:
219
+ path = os.path.join(config["output_dir"], wandb_run_name)
220
+ if os.path.exists(path):
221
+ last_checkpoint = get_last_checkpoint(path)
222
+ if last_checkpoint is None:
223
+ print_zero_rank("No checkpoint found, starting training from scratch.")
224
+ else:
225
+ print_zero_rank(f"Resuming training from the last checkpoint: {last_checkpoint}")
226
+
227
+ # Create trainer
228
+ trainer = ProtTrainer(
229
+ model=model,
230
+ train_dataset=train_dataset,
231
+ eval_dataset={"valid": valid_dataset, "train": train_eval_dataset},
232
+ optimizers=(optimizer, scheduler),
233
+ args=TrainingArguments(
234
+ run_name=wandb_run_name,
235
+ local_rank=int(os.getenv('LOCAL_RANK', '0')),
236
+ learning_rate=config["learning_rate"],
237
+ num_train_epochs=config["num_epochs"],
238
+ per_device_train_batch_size=config["batch_size"],
239
+ per_device_eval_batch_size=config["batch_size"],
240
+ gradient_accumulation_steps=config["gradient_accumulation_steps"],
241
+ eval_accumulation_steps=config["eval_accumulation_steps"],
242
+ eval_strategy="steps",
243
+ max_grad_norm=config["max_grad_norm"],
244
+ bf16=config["dtype"] == "bfloat16",
245
+ dataloader_num_workers=32,
246
+ logging_steps=config["logging_steps"],
247
+ eval_steps=config["eval_steps"],
248
+ save_steps=config["save_steps"],
249
+ output_dir=config["output_dir"] + '/' + wandb_run_name,
250
+ logging_dir=config["output_dir"] + '/' + wandb_run_name,
251
+ report_to="wandb" if is_zero_rank() else None,
252
+ log_on_each_node=False,
253
+ overwrite_output_dir=False,
254
+ push_to_hub=False,
255
+ label_names=["labels"],
256
+ ),
257
+ compute_only_fim_loss=config["compute_only_fim_loss"],
258
+ data_collator=data_collator,
259
+ compute_metrics=compute_metrics,
260
+ callbacks=callbacks,
261
+ )
262
+
263
+ # Train model
264
+ while True:
265
+ if last_checkpoint is None and trainer.state.global_step == 0:
266
+ eval_results = trainer.evaluate()
267
+ print_zero_rank(
268
+ f">>> Initial validation perplexity: {eval_results['eval_valid_perplexity/batch']:.2f}"
269
+ )
270
+ else:
271
+ print_zero_rank(f"Resuming training from the last checkpoint: {last_checkpoint}")
272
+ # Train
273
+ trainer.train(resume_from_checkpoint=last_checkpoint)
274
+
275
+ # Break training when the number of epochs is reached
276
+ if (
277
+ not es_callback.should_restart
278
+ or trainer.state.epoch >= config["num_epochs"]
279
+ ):
280
+ eval_results = trainer.evaluate()
281
+ print_zero_rank(
282
+ f">>> Final Perplexity: {eval_results['eval_valid_perplexity/batch']:.2f}"
283
+ )
284
+ break
285
+ # If the training was interrupted because of a loss spike, restart from the last checkpoint
286
+ last_checkpoint = es_callback.checkpoint_path
287
+
288
+ return trainer
289
+
290
+ if __name__ == "__main__":
291
+
292
+ # Default configuration file paths
293
+ default_model_config = "configs/xlstm_default_config.yaml"
294
+ default_train_config = "configs/train_default_config.yaml"
295
+
296
+ parser = argparse.ArgumentParser(
297
+ description="Train or finetune a model with the provided configuration."
298
+ )
299
+ parser.add_argument(
300
+ "--model_config_path",
301
+ type=str,
302
+ default=default_model_config,
303
+ help=f"Path to the model configuration file (default: {default_model_config})"
304
+ )
305
+ parser.add_argument(
306
+ "--train_config_path",
307
+ type=str,
308
+ default=default_train_config,
309
+ help=f"Path to the training and dataset configuration file (default: {default_train_config})"
310
+ )
311
+ parser.add_argument(
312
+ "overrides",
313
+ nargs=argparse.REMAINDER,
314
+ help="Override configuration values using key=value format.",
315
+ )
316
+
317
+ args = parser.parse_args()
318
+
319
+ # Check if the default config files exist, or raise an error
320
+ if not os.path.exists(args.model_config_path):
321
+ raise FileNotFoundError(f"Model config file not found: {args.model_config_path}")
322
+ if not os.path.exists(args.train_config_path):
323
+ raise FileNotFoundError(f"Train config file not found: {args.train_config_path}")
324
+
325
+ # Load the model and training configurations
326
+ model_config = OmegaConf.load(args.model_config_path)
327
+ train_config = OmegaConf.load(args.train_config_path)
328
+
329
+ # Merge the model and training configurations
330
+ config = OmegaConf.merge(model_config, train_config)
331
+
332
+ # Parse overrides
333
+ if args.overrides:
334
+ overrides = parse_override_args(args.overrides)
335
+ config.merge_with(OmegaConf.create(overrides))
336
+
337
+ # Run the training/finetuning process
338
+ run(config)
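The trailing `key=value` overrides accept dotted keys for nested config entries. A short sketch of how they are parsed and merged (the keys and values below are illustrative, not taken from the repository configs):

from omegaconf import OmegaConf
from protxlstm.utils import parse_override_args

# e.g. python protxlstm/train.py learning_rate=1e-4 model.num_blocks=16 wandb_mode=offline
overrides = parse_override_args(["learning_rate=1e-4", "model.num_blocks=16", "wandb_mode=offline"])
config = OmegaConf.create({"learning_rate": 1e-3, "wandb_mode": "online", "model": {"num_blocks": 8}})
config.merge_with(OmegaConf.create(overrides))
assert config.model.num_blocks == 16 and config.wandb_mode == "offline"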
protxlstm/trainer.py ADDED
@@ -0,0 +1,123 @@
1
+ # Original code from ProtMamba under Apache License 2.0.
2
+ #
3
+ # Modifications made by Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
4
+ # - MambaTrainer renamed to ProtTrainer
5
+
6
+ import os
7
+ import re
8
+
9
+ import torch
10
+ from transformers import Trainer, TrainerCallback
11
+
12
+ from protxlstm.utils import AA_TO_ID, find_fim_indices
13
+
14
+ class ProtTrainer(Trainer):
15
+ """
16
+ Base HuggingFace Trainer used for training.
17
+
18
+ from https://github.com/havenhq/mamba-chat/blob/main/trainer/mamba_trainer.py"""
19
+ def __init__(self, compute_only_fim_loss, **kwargs,):
20
+ super().__init__(**kwargs)
21
+ self.compute_only_fim_loss = compute_only_fim_loss
22
+
23
+
24
+ def compute_loss(self, model, inputs, return_outputs=False):
25
+ input_ids = inputs.pop("input_ids")
26
+ labels = inputs.pop("labels")
27
+ if "seq_position_ids" in inputs and "position_ids" in inputs:
28
+ position_ids = inputs.pop("position_ids")
29
+ seq_position_ids = inputs.pop("seq_position_ids")
30
+ output = model(input_ids, position_ids=position_ids, seq_position_ids=seq_position_ids)
31
+ elif "position_ids" in inputs:
32
+ position_ids = inputs.pop("position_ids")
33
+ output = model(input_ids, position_ids=position_ids)
34
+ else:
35
+ output = model(input_ids)
36
+ lm_logits = output.logits
37
+
38
+ labels = labels.to(lm_logits.device)
39
+ shift_logits = lm_logits[:, :-1, :].contiguous()
40
+ labels = labels[:, 1:].contiguous()
41
+
42
+ loss_fct = torch.nn.CrossEntropyLoss()
43
+ if self.compute_only_fim_loss:
44
+ # start and end tokens
45
+ is_cls_tokens = (labels == AA_TO_ID["<cls>"])
46
+ is_eos_tokens = (labels == AA_TO_ID["<eos>"])
47
+ bool_fim = find_fim_indices(is_cls_tokens, is_eos_tokens)
48
+ # include also the cls token
49
+ bool_fim = bool_fim | is_cls_tokens
50
+ inds = torch.where(bool_fim)
51
+ lm_loss = loss_fct(shift_logits[inds[0], inds[1], :], labels[bool_fim])
52
+ else:
53
+ lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
54
+
55
+ return (lm_loss, output) if return_outputs else lm_loss
56
+
57
+ def save_model(self, output_dir, _internal_call):
58
+ if int(os.getenv('LOCAL_RANK', '0')) == 0:
59
+ self.model.save_pretrained(output_dir)
60
+
61
+ PREFIX_CHECKPOINT_DIR = "checkpoint"
62
+ _re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$")
63
+
64
+ def get_last_checkpoint(folder, max_steps=None):
65
+ content = os.listdir(folder)
66
+ checkpoints = [
67
+ path
68
+ for path in content
69
+ if _re_checkpoint.search(path) is not None and os.path.isdir(os.path.join(folder, path))
70
+ ]
71
+ if len(checkpoints) == 0:
72
+ return
73
+
74
+ max_steps = max_steps if max_steps is not None else float("inf")
75
+ # func = lambda x: int(_re_checkpoint.search(x).groups()[0])
76
+ def func(x):
77
+ num = int(_re_checkpoint.search(x).groups()[0])
78
+ return num if num < max_steps else -1
79
+ return os.path.join(folder, max(checkpoints, key=func))
80
+
81
+ class EarlyStoppingCallback(TrainerCallback):
82
+ def __init__(self, train_path, config=None):
83
+ self.step_counter_reset = 0
84
+ self.step_counter_stop = 0
85
+ self.best_loss = None
86
+ self.train_path = train_path
87
+ self.patience = config["patience"]
88
+ self.metric_name = config["early_stopping_metric"]
89
+ self.checkpoint_path = None
90
+ self.should_restart = False
91
+ self.eval_steps = config["eval_steps"]
92
+ self.loss_increase_factor = config["loss_increase_factor"]
93
+
94
+ def get_checkpoint_path(self, max_steps):
95
+ last_checkpoint = None
96
+ if os.path.exists(self.train_path):
97
+ last_checkpoint = get_last_checkpoint(self.train_path, max_steps)
98
+ if last_checkpoint is None:
99
+ print("No checkpoint found, starting training from scratch.")
100
+ else:
101
+ print(f"Max checkpoint allowed: {max_steps}, restarting from {last_checkpoint}.")
102
+ return last_checkpoint
103
+
104
+ def on_evaluate(self, args, state, control, model, metrics, **kwargs):
105
+ if self.metric_name in metrics:
106
+ if self.best_loss is None:
107
+ self.best_loss = metrics[self.metric_name]
108
+ elif self.best_loss*self.loss_increase_factor < metrics[self.metric_name]:
109
+ self.step_counter += 1
110
+ if self.step_counter >= self.patience:
111
+ checkpoint_path = self.get_checkpoint_path(max_steps=(state.global_step-self.patience*self.eval_steps))
112
+ control.should_training_stop = True
113
+ self.checkpoint_path = checkpoint_path
114
+ self.should_restart = True
115
+ else:
116
+ self.step_counter = 0
117
+ self.best_loss = min(self.best_loss, metrics[self.metric_name])
118
+ self.should_restart = False
119
+
120
+ def on_train_begin(self, args, state, control, **kwargs):
121
+ self.step_counter = 0
122
+ self.best_loss = None
123
+ self.should_restart = False
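A quick sketch of how `get_last_checkpoint` above picks the checkpoint to resume from (directory names are made up):

import os
import tempfile
from protxlstm.trainer import get_last_checkpoint

with tempfile.TemporaryDirectory() as run_dir:
    # create fake HF-style checkpoint folders
    for step in (500, 1000, 1500):
        os.makedirs(os.path.join(run_dir, f"checkpoint-{step}"))
    print(get_last_checkpoint(run_dir))                   # .../checkpoint-1500 (newest)
    print(get_last_checkpoint(run_dir, max_steps=1200))   # .../checkpoint-1000 (newest below max_steps)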
protxlstm/utils.py ADDED
@@ -0,0 +1,482 @@
1
+ # Some of the objects in this file come from ProtMamba and mamba both under Apache License 2.0.
2
+
3
+ import json
4
+ import os
5
+
6
+ import numpy as np
7
+ import rich
8
+ import torch
9
+ from Bio import SeqIO
10
+ from omegaconf import DictConfig, OmegaConf
11
+ from torch.optim import AdamW
12
+ import wandb
13
+ from transformers import (
14
+ get_constant_schedule_with_warmup,
15
+ get_cosine_schedule_with_warmup,
16
+ get_cosine_with_hard_restarts_schedule_with_warmup,
17
+ )
18
+ from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
19
+ from transformers.utils.hub import cached_file
20
+
21
+ __all__ = ['AA_TO_ID', 'MASK_TO_ID', 'ID_TO_AA', 'load_model', 'encode_sequence', 'decode_sequence', 'clean_sequence', 'tokenizer',
22
+ 'reorder_masked_sequence', 'load_sequences_from_msa_file', 'prepare_dataset_for_fim_generation',
23
+ 'prepare_tokens', 'prepare_target', 'print_number_of_parameters', 'find_fim_indices',
24
+ 'compute_metrics', 'compute_metrics_with_std', 'print_config', 'print_zero_rank', 'is_zero_rank']
25
+
26
+ # Constants
27
+ AA_TO_ID = {'<cls>': 0,
28
+ '<pad>': 1,
29
+ '<eos>': 2,
30
+ '<unk>': 3,
31
+ 'L': 4,
32
+ 'A': 5,
33
+ 'G': 6,
34
+ 'V': 7,
35
+ 'S': 8,
36
+ 'E': 9,
37
+ 'R': 10,
38
+ 'T': 11,
39
+ 'I': 12,
40
+ 'D': 13,
41
+ 'P': 14,
42
+ 'K': 15,
43
+ 'Q': 16,
44
+ 'N': 17,
45
+ 'F': 18,
46
+ 'Y': 19,
47
+ 'M': 20,
48
+ 'H': 21,
49
+ 'W': 22,
50
+ 'C': 23,
51
+ 'X': 24,
52
+ 'B': 25,
53
+ 'U': 26,
54
+ 'Z': 27,
55
+ 'O': 28,
56
+ '.': 29,
57
+ '-': 30,
58
+ '<null_1>': 31,
59
+ '<mask>': 32}
60
+
61
+ MASK_TO_ID = {"<mask-1>": 33,
62
+ "<mask-2>": 34,
63
+ "<mask-3>": 35,
64
+ "<mask-4>": 36,
65
+ "<mask-5>": 37,}
66
+
67
+ AA_TO_ID.update(MASK_TO_ID)
68
+
69
+ ID_TO_AA = {v: k for k, v in AA_TO_ID.items()}
70
+
71
+ # Logging & prints
72
+ def setup_wandb(config):
73
+
74
+ # WandB setup
75
+ os.environ["WANDB_PROJECT"] = config["wandb_project"]
76
+ os.environ["WANDB_ENTITY"] = config["wandb_entity"]
77
+ os.environ["WANDB_MODE"] = config["wandb_mode"]
78
+
79
+ if config['model_type'] == 'xlstm':
80
+ pe = config['model']['add_position_ids']
81
+ pe = 'None' if pe == 'none' else 'AbsPE' if pe == 'abs_1d' else 'AbsPE2' if pe == 'abs_2d' else 'RoPE' if pe == 'rot_1d' else pe == 'rot_2d'
82
+ wandb_run_name = f"{config['model_type']}_l{config['model']['num_blocks']}_d{config['model']['embedding_dim']}_{pe}_s{config['max_msa_len']}_lr{config['learning_rate']}"
83
+ elif config['model_type'] == 'mamba':
84
+ pe = config['model']['add_position_ids']
85
+ pe = 'None' if pe == 'none' else 'AbsPE' if pe == '1d' else pe == '2d'
86
+ wandb_run_name = f"{config['model_type']}_l{config['model']['n_layer']}_d{config['model']['d_model']}_{pe}_s{config['max_msa_len']}_lr{config['learning_rate']}"
87
+ elif config['model_type'] == 'llama':
88
+ pe = 'RoPE'
89
+ wandb_run_name = f"{config['model_type']}_l{config['model']['n_layer']}_d{config['model']['d_model']}_dh{config['model']['hidden_dim']}_{prepare_dataset_for_fim_generation}_s{config['max_msa_len']}_lr{config['learning_rate']}_sched-{config['scheduler']}"
90
+
91
+ if config['name_prefix']:
92
+ wandb_run_name = str(config['name_prefix']) + '_' + wandb_run_name
93
+ if config['name_suffix']:
94
+ wandb_run_name = wandb_run_name + '_' + str(config['name_suffix'])
95
+
96
+ if is_zero_rank():
97
+ wandb.init(
98
+ project=config["wandb_project"],
99
+ entity=config["wandb_entity"],
100
+ mode=config["wandb_mode"],
101
+ name=wandb_run_name)
102
+ config_dict = OmegaConf.to_container(config, resolve=True)
103
+ wandb.config.update(config_dict)
104
+ return wandb_run_name
105
+
106
+ def is_zero_rank():
107
+ return int(os.getenv('LOCAL_RANK', '0')) == 0
108
+
109
+ def print_zero_rank(var):
110
+ if is_zero_rank():
111
+ print(var)
112
+
113
+ def print_number_of_parameters(model):
114
+ num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
115
+ formatted_num_params = f"{num_params:_}"
116
+ print("Number of trainable parameters: ", formatted_num_params)
117
+
118
+ # Sequence tools
119
+ def encode_sequence(sequence):
120
+ """Tokenize a sequence of amino acids and add a cls token at the beginning."""
121
+ tokenized_sequence = [AA_TO_ID[aa] if aa in AA_TO_ID else AA_TO_ID['<unk>'] for aa in sequence]
122
+ return [AA_TO_ID['<cls>']] + tokenized_sequence
123
+
124
+ def decode_sequence(sequence):
125
+ """Decode a sequence of tokens."""
126
+ return "".join([ID_TO_AA[token] if token in ID_TO_AA else "<unk>" for token in sequence])
127
+
128
+ def clean_sequence(sequence):
129
+ """Remove gaps and convert all residues to upper case."""
130
+ return sequence.replace("-", "").upper()
131
+
132
+ def tokenizer(sequence_list, concatenate=True):
133
+ """Tokenize a collection of sequences. If the sequences are aligned, the gaps will be removed
134
+ and the insertions (lower case) will be promoted to upper case."""
135
+ # clean and encode all sequences
136
+ sequence_list = [encode_sequence(clean_sequence(sequence)) for sequence in sequence_list]
137
+ if concatenate:
138
+ # concatenate all sequences
139
+ sequences = np.concatenate(sequence_list)
140
+ # convert to tensor and add batch dimension
141
+ return torch.asarray(sequences, dtype=torch.int8)[None,:]
142
+ else:
143
+ return [torch.asarray(sequence, dtype=torch.int8) for sequence in sequence_list]
144
+
145
+ def reorder_masked_sequence(mask_seq, return_ids=False):
146
+ """
147
+ Reorder a masked sequence to fill the masked positions with the tokens
148
+ that should be there but are positioned after the <eos> token.
149
+ """
150
+ mask_seq = mask_seq.split("<cls>")[0]
151
+ try:
152
+ # Split the sequence and masks
153
+ seq, masks = mask_seq.split("<eos>")
154
+ except:
155
+ return mask_seq
156
+ full_seq = ""
157
+ ids_mask = []
158
+ # Iterate over each mask tag
159
+ for mm in ["<mask-1>", "<mask-2>", "<mask-3>", "<mask-4>", "<mask-5>","<mask-?>"]:
160
+ try:
161
+ # Split the sequence in before and after the mask tag
162
+ seq1, seq2 = seq.split(mm)
163
+ if mm=="<mask-1>":
164
+ # If the mask is the first one, add the sequence before the mask and update the masks
165
+ masks = masks.split("<mask-1>")[1]
166
+ full_seq += seq1
167
+ else:
168
+ # If the mask is not the first one, insert the mask between the two sequence parts
169
+ masks1, masks2 = masks.split(mm)
170
+ ids_mask += [(len(full_seq), len(full_seq)+len(masks1))]
171
+ full_seq += masks1 + seq1
172
+ # Update the masks
173
+ masks = masks2
174
+ # Update the sequence with the part after the mask
175
+ seq = seq2
176
+ except:
177
+ # If the mask is not found, add the remaining sequence
178
+ ids_mask += [(len(full_seq), len(full_seq)+len(masks))]
179
+ full_seq += masks + seq
180
+ break
181
+ if return_ids:
182
+ return full_seq, ids_mask
183
+ return full_seq
184
+
185
+ def load_sequences_from_msa_file(file_path):
186
+ """Load a collection of sequences from an a3m file."""
187
+ with open(file_path, "r") as f:
188
+ sequences = [str(record.seq) for record in SeqIO.parse(f, "fasta")]
189
+ return sequences
190
+
191
+ def prepare_dataset_for_fim_generation(tokens, pos_ids):
192
+ """
193
+ Function to transform the tokenized training dataset into a format that can be used for FIM generation.
194
+ Splits the input tokens and pos_ids into the FIM part (of the last sequence) and the context part (all
195
+ the previous sequences and the masked part of the last sequence).
196
+ Also returns a dictionary with the positions of the mask tokens in the FIM part.
197
+ """
198
+ def find_mask_positions(tokens_fim):
199
+ """
200
+ Function to find the positions of the mask tokens in the FIM part of the last sequence.
201
+ """
202
+ bool_mask = None
203
+ inds_masks = []
204
+ for ind in MASK_TO_ID.values():
205
+ tmp_bool = tokens_fim[0].cpu().numpy() == ind
206
+ bool_mask = tmp_bool if bool_mask is None else bool_mask | tmp_bool
207
+ inds_masks += [ind]
208
+ return bool_mask, inds_masks
209
+ # find where the FIM part of the last sequence starts
210
+ start_last_fim = np.where(tokens[0].cpu().numpy() == AA_TO_ID["<eos>"])[0][-1]
211
+ start_next_seqs = np.where(tokens[0,start_last_fim+1:].cpu().numpy() == AA_TO_ID["<cls>"])[0]
212
+ end_last_fim = start_last_fim+ 1 +start_next_seqs[0] if len(start_next_seqs) > 0 else tokens.shape[1]
213
+ # split tokens and pos_ids into FIM part and context part
214
+ tokens_to_fim = tokens[:,:start_last_fim+1]
215
+ pos_ids_to_fim = pos_ids[:,:start_last_fim+1]
216
+ tokens_fim = tokens[:,start_last_fim+1:end_last_fim]
217
+ pos_ids_fim = pos_ids[:,start_last_fim+1:end_last_fim]
218
+ # find positions of mask tokens
219
+ bool_mask, inds_masks = find_mask_positions(tokens_fim)
220
+ masked_positions = pos_ids_fim[0,bool_mask]
221
+ mask_dict = {ind: int(pos) for ind, pos in zip(inds_masks, masked_positions)}
222
+ return tokens_to_fim, pos_ids_to_fim, tokens_fim, pos_ids_fim, mask_dict
223
+
224
+ # Metrics
225
+ def find_fim_indices(is_cls_tokens, is_eos_tokens):
226
+ """Function to find the indices of the FIM tokens in the sequences.
227
+ """
228
+ # add a cls token at the beginning
229
+ is_cls_tokens = torch.cat([torch.ones_like(is_cls_tokens[:, :1]), is_cls_tokens], dim=1)
230
+ is_eos_tokens = torch.cat([torch.zeros_like(is_eos_tokens[:, :1]), is_eos_tokens], dim=1)
231
+ # both eos and cls tokens
232
+ bol = is_cls_tokens | is_eos_tokens
233
+ tmp = torch.zeros_like(is_cls_tokens, dtype=torch.int)
234
+ tmp[torch.nonzero(is_cls_tokens, as_tuple=True)] = 1
235
+ tmp[torch.nonzero(is_eos_tokens, as_tuple=True)] = -1
236
+ bol1 = torch.clone(bol)
237
+ for batch_ind in range(tmp.size(0)):
238
+ tmp1 = tmp[batch_ind,bol[batch_ind]]
239
+ # find all positions where a 1 if preceeded by a -1
240
+ tmp1 = tmp1[:-1]*tmp1[1:]
241
+ # add the first element to make the sequence start with a 1
242
+ tmp1 = torch.cat([torch.ones_like(tmp1[:1]).to(tmp1.device), tmp1])
243
+ new_bol = tmp1<0
244
+ # bool array True only in the positions where a 1 is preceeded by a -1
245
+ bol1[batch_ind,bol[batch_ind]] = False if new_bol.size(0) == 0 else new_bol
246
+ cumulative_sum = torch.cumsum(bol1, dim=1)
247
+ # Use modulo operation to get the desired tensor
248
+ bol2 = cumulative_sum % 2 == 1
249
+ bol2[is_eos_tokens]= False
250
+ return bol2[:,1:]
251
+
252
+ def compute_metrics(eval_pred):
253
+ predictions, labels = eval_pred
254
+ predictions = torch.tensor(predictions).permute(0, 2, 1)
255
+ labels = torch.tensor(labels)
256
+ # shift labels to align them with predictions and remove last prediction to match the length
257
+ predictions = predictions[:, :, :-1].contiguous()
258
+ labels = labels[:, 1:].contiguous()
259
+ # compute unreduced elementwise loss
260
+ unreduced_loss = torch.nn.functional.cross_entropy(predictions, labels, reduction="none")
261
+ # compute reconstruction accuracy
262
+ reconstruction = (predictions.argmax(1) == labels)
263
+
264
+ # start and end tokens
265
+ is_cls_tokens = (labels == AA_TO_ID["<cls>"])
266
+ is_eos_tokens = (labels == AA_TO_ID["<eos>"])
267
+ # fill in the middle tokens
268
+ if False:
269
+ fim_tokens = torch.zeros(is_cls_tokens.size(0), is_cls_tokens.size(1), dtype=torch.bool)
270
+ in_mask_vector = torch.zeros(is_cls_tokens.size(0), dtype=torch.bool)
271
+ for j in range(is_cls_tokens.size(1)):
272
+ in_mask_vector = in_mask_vector & ~is_cls_tokens[:, j]
273
+ fim_tokens[:, j] = in_mask_vector
274
+ in_mask_vector = in_mask_vector | is_eos_tokens[:, j]
275
+ fim_tokens = find_fim_indices(is_cls_tokens, is_eos_tokens)
276
+
277
+ number_sequences = torch.cumsum(torch.cat([torch.zeros(is_cls_tokens.size(0),1, dtype=torch.int32), is_cls_tokens[:,:-1]],1), -1)
278
+ # fist, second and last sequence tokens
279
+ first_sequence_tokens = ((~fim_tokens & (labels < 33)) | fim_tokens) & (number_sequences == 0)
280
+ second_sequence_tokens = ((~fim_tokens & (labels < 33)) | fim_tokens) & (number_sequences == 1)
281
+ last_sequence_tokens = ((~fim_tokens & (labels < 33)) | fim_tokens) & (number_sequences == (number_sequences.max(1).values[:, None] - 1))
282
+ # end of mask tokens
283
+ end_of_masks = (fim_tokens & (labels > 33)) | is_cls_tokens | is_eos_tokens
284
+
285
+ return {
286
+ "loss/all": torch.mean(unreduced_loss).item(),
287
+ "loss/end_span": torch.mean(unreduced_loss[end_of_masks]).item(),
288
+ "perplexity/seq": torch.mean(torch.exp(torch.mean(unreduced_loss, dim=1))).item(),
289
+ "perplexity/end_span": torch.exp(torch.mean(unreduced_loss[end_of_masks])).item(),
290
+ "perplexity/batch": torch.exp(torch.mean(unreduced_loss)).item(),
291
+ "perplexity/first_seq": torch.exp(torch.mean(unreduced_loss[first_sequence_tokens])).item(),
292
+ "perplexity/second_seq": torch.exp(torch.mean(unreduced_loss[second_sequence_tokens])).item(),
293
+ "perplexity/last_seq": torch.exp(torch.mean(unreduced_loss[last_sequence_tokens])).item(),
294
+ "perplexity/fim": torch.exp(torch.mean(unreduced_loss[fim_tokens])).item(),
295
+ "reconstruction/all": torch.mean(reconstruction.float()).item(),
296
+ "reconstruction/end_span": torch.mean(reconstruction[end_of_masks].float()).item(),
297
+ "reconstruction/first_seq": torch.mean(reconstruction[first_sequence_tokens].float()).item(),
298
+ "reconstruction/second_seq": torch.mean(reconstruction[second_sequence_tokens].float()).item(),
299
+ "reconstruction/last_seq": torch.mean(reconstruction[last_sequence_tokens].float()).item(),
300
+ "reconstruction/fim": torch.mean(reconstruction[fim_tokens].float()).item(),
301
+ }
302
+
303
+ def compute_metrics_with_std(eval_pred):
304
+ predictions, labels = eval_pred
305
+ predictions = torch.tensor(predictions).permute(0, 2, 1)
306
+ labels = torch.tensor(labels)
307
+ # shift labels to align them with predictions and remove last prediction to match the length
308
+ predictions = predictions[:, :, :-1].contiguous()
309
+ labels = labels[:, 1:].contiguous()
310
+ # compute unreduced elementwise loss
311
+ unreduced_loss = torch.nn.functional.cross_entropy(predictions, labels, reduction="none")
312
+ # compute reconstruction accuracy
313
+ reconstruction = (predictions.argmax(1) == labels)
314
+
315
+ # start and end tokens
316
+ is_cls_tokens = (labels == AA_TO_ID["<cls>"])
317
+ is_eos_tokens = (labels == AA_TO_ID["<eos>"])
318
+ # fill in the middle tokens
319
+ if False:
320
+ fim_tokens = torch.zeros(is_cls_tokens.size(0), is_cls_tokens.size(1), dtype=torch.bool)
321
+ in_mask_vector = torch.zeros(is_cls_tokens.size(0), dtype=torch.bool)
322
+ for j in range(is_cls_tokens.size(1)):
323
+ in_mask_vector = in_mask_vector & ~is_cls_tokens[:, j]
324
+ fim_tokens[:, j] = in_mask_vector
325
+ in_mask_vector = in_mask_vector | is_eos_tokens[:, j]
326
+ fim_tokens = find_fim_indices(is_cls_tokens, is_eos_tokens)
327
+
328
+ number_sequences = torch.cumsum(torch.cat([torch.zeros(is_cls_tokens.size(0),1, dtype=torch.int32), is_cls_tokens[:,:-1]],1), -1)
329
+ # fist, second and last sequence tokens
330
+ first_sequence_tokens = ((~fim_tokens & (labels < 33)) | fim_tokens) & (number_sequences == 0)
331
+ second_sequence_tokens = ((~fim_tokens & (labels < 33)) | fim_tokens) & (number_sequences == 1)
332
+ last_sequence_tokens = ((~fim_tokens & (labels < 33)) | fim_tokens) & (number_sequences == (number_sequences.max(1).values[:, None] - 1))
333
+ # end of mask tokens
334
+ end_of_masks = (fim_tokens & (labels > 33)) | is_cls_tokens | is_eos_tokens
335
+
336
+ def perplexities_per_seq_for_subset(unreduced_loss, subset):
337
+ return torch.exp(torch.nanmean(torch.where(subset, unreduced_loss, torch.tensor(float('nan'))), dim=1))
338
+
339
+ return{
340
+ # Loss
341
+ "loss/all": torch.mean(unreduced_loss).item(),
342
+ "loss/std": torch.std(unreduced_loss).item(),
343
+ "loss/end_span": torch.mean(unreduced_loss[end_of_masks]).item(),
344
+ "loss/end_span_std": torch.std(unreduced_loss[end_of_masks]).item(),
345
+
346
+ # Perplexity of all tokens
347
+ "perplexity/batch": torch.exp(torch.mean(unreduced_loss)).item(),
348
+ "perplexity/batch_std": torch.exp(torch.std(unreduced_loss)).item(), # Fix
349
+
350
+ # Perplexity per sequence
351
+ "perplexity/seq": torch.mean(torch.exp(torch.mean(unreduced_loss, dim=1))).item(),
352
+ "perplexity/seq_std": torch.std(torch.exp(torch.mean(unreduced_loss, dim=1))).item(),
353
+ "perplexity/end_span": torch.exp(torch.mean(unreduced_loss[end_of_masks])).item(),
354
+ "perplexity/end_span_std": torch.std(torch.exp(unreduced_loss[end_of_masks])).item(),
355
+
356
+ "perplexity/first_seq": torch.mean(perplexities_per_seq_for_subset(unreduced_loss, first_sequence_tokens)).item(),
357
+ "perplexity/first_seq_std": torch.std(perplexities_per_seq_for_subset(unreduced_loss, first_sequence_tokens)).item(),
358
+ "perplexity/second_seq": torch.mean(perplexities_per_seq_for_subset(unreduced_loss, second_sequence_tokens)).item(),
359
+ "perplexity/second_seq_std": torch.std(perplexities_per_seq_for_subset(unreduced_loss, second_sequence_tokens)).item(),
360
+ "perplexity/last_seq": torch.mean(perplexities_per_seq_for_subset(unreduced_loss, last_sequence_tokens)).item(),
361
+ "perplexity/last_seq_std": torch.std(perplexities_per_seq_for_subset(unreduced_loss, last_sequence_tokens)).item(),
362
+ "perplexity/fim": torch.mean(perplexities_per_seq_for_subset(unreduced_loss, fim_tokens)).item(),
363
+ "perplexity/fim_std": torch.std(perplexities_per_seq_for_subset(unreduced_loss, fim_tokens)).item(),
364
+ "reconstruction/all": torch.mean(reconstruction.float()).item(),
365
+ "reconstruction/std": torch.std(reconstruction.float()).item(),
366
+ "reconstruction/end_span": torch.mean(reconstruction[end_of_masks].float()).item(),
367
+ "reconstruction/end_span_std": torch.std(reconstruction[end_of_masks].float()).item(),
368
+ "reconstruction/first_seq": torch.mean(reconstruction[first_sequence_tokens].float()).item(),
369
+ "reconstruction/first_seq_std": torch.std(reconstruction[first_sequence_tokens].float()).item(),
370
+ "reconstruction/second_seq": torch.mean(reconstruction[second_sequence_tokens].float()).item(),
371
+ "reconstruction/second_seq_std": torch.std(reconstruction[second_sequence_tokens].float()).item(),
372
+ "reconstruction/last_seq": torch.mean(reconstruction[last_sequence_tokens].float()).item(),
373
+ "reconstruction/last_seq_std": torch.std(reconstruction[last_sequence_tokens].float()).item(),
374
+ "reconstruction/fim": torch.mean(reconstruction[fim_tokens].float()).item(),
375
+ "reconstruction/fim_std": torch.std(reconstruction[fim_tokens].float()).item(),
376
+ }
377
+
378
+ # Others
379
+ def set_optimizer_and_scheduler(config, ntrain, parameters):
380
+
381
+ # Set optimizer
382
+ optimizer = AdamW(
383
+ parameters,
384
+ lr=config["learning_rate"],
385
+ betas=(config["beta1"], config["beta2"]),
386
+ weight_decay=config["weight_decay"],
387
+ )
388
+
389
+ eff_batch_size = config["batch_size"] * config["gradient_accumulation_steps"] * torch.cuda.device_count()
390
+
391
+ # Set scheduler
392
+ if config["scheduler"] == "cosine":
393
+ print_zero_rank("Using cosine scheduler")
394
+ scheduler = get_cosine_schedule_with_warmup(
395
+ optimizer,
396
+ num_warmup_steps=config["warmup_steps"],
397
+ num_training_steps=config["num_epochs"] * ntrain // eff_batch_size,
398
+ )
399
+ if config["scheduler"] == "cosine-restarts":
400
+ scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
401
+ optimizer,
402
+ num_warmup_steps=config["warmup_steps"],
403
+ num_training_steps=config["num_epochs"] * ntrain // eff_batch_size,
404
+ num_cycles=config["num_cycles"],
405
+ )
406
+ elif config["scheduler"] == "constant":
407
+ print_zero_rank("Using constant scheduler with warmup")
408
+ scheduler = get_constant_schedule_with_warmup(
409
+ optimizer, num_warmup_steps=config["warmup_steps"]
410
+ )
411
+ else:
412
+ raise ValueError("Scheduler must be either cosine or constant")
413
+
414
+ # Finetuning and no optimizer/scheduler reset
415
+ if config.finetune_model_path and not config.restart_optimizer_and_scheduler:
416
+ optimizer.load_state_dict(torch.load(config.finetune_model_path + "/optimizer.pt"))
417
+ for param_group in optimizer.param_groups:
418
+ param_group['initial_lr'] = config['learning_rate']
419
+ param_group['lr'] = config['learning_rate']
420
+
421
+ scheduler.load_state_dict(torch.load(config.finetune_model_path + "/scheduler.pt"))
422
+ scheduler.base_lrs = [config['learning_rate']]
423
+ scheduler._last_lr = [config['learning_rate']]
424
+
425
+ return optimizer, scheduler
426
+
427
+ def parse_override_args(override_args):
428
+ overrides = {}
429
+ for arg in override_args:
430
+ key, value = arg.split("=")
431
+ keys = key.split(".")
432
+ sub_dict = overrides
433
+ for sub_key in keys[:-1]:
434
+ if sub_key not in sub_dict:
435
+ sub_dict[sub_key] = {}
436
+ sub_dict = sub_dict[sub_key]
437
+ # Convert value to appropriate type
438
+ if value == 'True':
439
+ value = True
440
+ elif value == 'False':
441
+ value = False
442
+ elif value == 'None':
443
+ value = None
444
+ else:
445
+ try:
446
+ value = int(value)
447
+ except ValueError:
448
+ try:
449
+ value = float(value)
450
+ except ValueError:
451
+ pass
452
+ sub_dict[keys[-1]] = value
453
+ return overrides
454
+
455
+ def load_model(
456
+ model_path,
457
+ device,
458
+ model_class,
459
+ dtype=torch.bfloat16,
460
+ **kwargs
461
+ ):
462
+ model = model_class.from_pretrained(
463
+ model_path, device=device, dtype=dtype, **kwargs
464
+ )
465
+ return model
466
+
467
+ # https://github.com/state-spaces/mamba/blob/main/mamba_ssm/utils/hf.py
468
+ def load_config_hf(model_name):
469
+ resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
470
+ return json.load(open(resolved_archive_file))
471
+
472
+ # https://github.com/state-spaces/mamba/blob/main/mamba_ssm/utils/hf.py
473
+ def load_state_dict_hf(model_name, device=None, dtype=None):
474
+ # If not fp32, then we don't want to load directly to the GPU
475
+ mapped_device = "cpu" if dtype not in [torch.float32, None] else device
476
+ resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
477
+ state_dict = torch.load(resolved_archive_file, map_location=mapped_device)
478
+ # Convert dtype before moving to GPU to save memory
479
+ if dtype is not None:
480
+ state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
481
+ state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
482
+ return state_dict
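A small sketch of the sequence helpers defined above (the sequences are made up; the comments state the expected behaviour of the code as written):

from protxlstm.utils import tokenizer, decode_sequence, reorder_masked_sequence

tokens = tokenizer(["MKTAYIAK", "MKT-ayIAK"])      # gaps removed, insertions upper-cased, one <cls> per sequence
print(tokens.shape)                                # torch.Size([1, 18])
print(decode_sequence(tokens[0].tolist()))         # "<cls>MKTAYIAK<cls>MKTAYIAK"
# FIM reordering: the span emitted after <eos> is spliced back in at its <mask-1> position
print(reorder_masked_sequence("MK<mask-1>IAK<eos><mask-1>TAY"))   # "MKTAYIAK"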
protxlstm/xlstm/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .blocks.mlstm.block import mLSTMBlock, mLSTMBlockConfig
2
+ from .blocks.mlstm.layer import mLSTMLayer, mLSTMLayerConfig
3
+ from .components.feedforward import FeedForwardConfig, GatedFeedForward
4
+ from .components.rotary_position import compute_freqs_cis, apply_rotary_emb
5
+ from .xlstm_block_stack import xLSTMBlockStack, xLSTMBlockStackConfig
6
+ from .xlstm_lm_model import xLSTMLMModel, xLSTMLMModelConfig
protxlstm/xlstm/blocks/__init__.py ADDED
File without changes
protxlstm/xlstm/blocks/mlstm/__init__.py ADDED
@@ -0,0 +1 @@
1
+
protxlstm/xlstm/blocks/mlstm/backends.py ADDED
@@ -0,0 +1,314 @@
1
+ # Copyright (c) NXAI GmbH and its affiliates 2024
2
+ # Maximilian Beck
3
+
4
+ # Modified by Pieter-Jan Hoedt, Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
5
+ # - Fix numerical issues between parallel and stepwise backends
6
+ # - Make chunkwise implementation compatible with stepwise backend and variable sequence lengths
7
+
8
+
9
+ import math
10
+ from typing import Union, Tuple, Optional
11
+ import torch
12
+
13
+
14
+ def parallel_stabilized_simple(
15
+ queries: torch.Tensor,
16
+ keys: torch.Tensor,
17
+ values: torch.Tensor,
18
+ igate_preact: torch.Tensor,
19
+ fgate_preact: torch.Tensor,
20
+ lower_triangular_matrix: torch.Tensor = None,
21
+ stabilize_rowwise: bool = True,
22
+ eps: float = 1e-6,
23
+ **kwargs,
24
+ ) -> torch.Tensor:
25
+ """This is the mLSTM cell in parallel form.
26
+ This version is stabilized. We control the range of exp() arguments by
27
+ ensuring that they are always smaller than 0.0 by subtracting the maximum.
28
+
29
+ Args:
30
+ queries (torch.Tensor): (B, NH, S, DH)
31
+ keys (torch.Tensor): (B, NH, S, DH)
32
+ values (torch.Tensor): (B, NH, S, DH)
33
+ igate_preact (torch.Tensor): (B, NH, S, 1)
34
+ fgate_preact (torch.Tensor): (B, NH, S, 1)
35
+ lower_triangular_matrix (torch.Tensor, optional): (S,S). Defaults to None.
36
+ stabilize_rowwise (bool, optional): Whether to stabilize the combination matrix C rowwise (take maximum per row).
37
+ Alternative: Subtract the maximum over all rows. Defaults to True.
38
+
39
+ Returns:
40
+ torch.Tensor: (B, NH, S, DH), h_tilde_state
41
+ """
42
+
43
+ B, NH, S, DH = queries.shape
44
+ _dtype, _device = queries.dtype, queries.device
45
+
46
+ # forget gate matrix
47
+ log_fgates = torch.nn.functional.logsigmoid(fgate_preact) # (B, NH, S, 1)
48
+ if lower_triangular_matrix is None or S < lower_triangular_matrix.size(-1):
49
+ ltr = torch.tril(torch.ones((S, S), dtype=torch.bool, device=_device))
50
+ else:
51
+ ltr = lower_triangular_matrix
52
+ assert (
53
+ ltr.dtype == torch.bool
54
+ ), f"lower_triangular_matrix must be of dtype bool, got {ltr.dtype}"
55
+
56
+ log_f_mat = torch.tril(log_fgates.repeat(1, 1, 1, S), diagonal=-1)
57
+ log_prod_f_mat = torch.cumsum(log_f_mat, dim=-2)
58
+ # Causal masking & selection of the correct submatrix, such that forgetgate at timestep t is not applied
59
+ # to the input at timestep t
60
+ log_fg_matrix = torch.where(ltr, log_prod_f_mat, -float("inf")) # (B, NH, S, S)
61
+
62
+ # gate decay matrix D (combination of forget gate and input gate)
63
+ log_D_matrix = log_fg_matrix + igate_preact.transpose(-2, -1) # (B, NH, S, S)
64
+ # D matrix stabilization
65
+ if stabilize_rowwise:
66
+ max_log_D, _ = torch.max(log_D_matrix, dim=-1, keepdim=True) # (B, NH, S, 1)
67
+ else:
68
+ max_log_D = torch.max(log_D_matrix.view(B, NH, -1), dim=-1, keepdim=True)[
69
+ 0
70
+ ].unsqueeze(-1)
71
+ # (B, NH, 1, 1)
72
+ log_D_matrix_stabilized = log_D_matrix - max_log_D # (B, NH, S, S)
73
+ D_matrix = torch.exp(log_D_matrix_stabilized) # (B, NH, S, S)
74
+
75
+ keys_scaled = keys / math.sqrt(DH)
76
+
77
+ # combination matrix C
78
+ qk_matrix = queries @ keys_scaled.transpose(-2, -1) # (B, NH, S, S)
79
+ C_matrix = qk_matrix * D_matrix # (B, NH, S, S)
80
+ normalizer = torch.maximum(
81
+ C_matrix.sum(dim=-1, keepdim=True).abs(), torch.exp(-max_log_D)
82
+ ) # (B, NH, S, 1)
83
+ # (B, NH, S, S)
84
+ C_matrix_normalized = C_matrix / (normalizer + eps)
85
+
86
+ # retrieved values
87
+ h_tilde_state = C_matrix_normalized @ values # (B, NH, S, DH)
88
+
89
+ return h_tilde_state
90
+
91
+
92
+ def chunkwise_simple(
93
+ queries: torch.Tensor,
94
+ keys: torch.Tensor, # B, NH, S, DH
95
+ values: torch.Tensor, # B, NH, S, DH
96
+ igate_preact: torch.Tensor, # B, NH, S
97
+ fgate_preact: torch.Tensor, # B, NH, S
98
+ initial_C: Optional[torch.Tensor] = None, # B, NH, DH, DH
99
+ initial_n: Optional[torch.Tensor] = None, # B, NH, DH, 1
100
+ initial_m: Optional[torch.Tensor] = None, # B, NH, 1, 1
101
+ chunk_size: int = 64, # optimize this
102
+ return_last_state: bool = False,
103
+ eps: float = 1e-6,
104
+ **kwargs,
105
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
106
+ B, NH, S, DH = queries.shape
107
+ NS, CS = S // chunk_size, chunk_size
108
+ _dtype, _device = queries.dtype, queries.device
109
+
110
+ # form chunks
111
+ q = queries.view(B, NH, NS, CS, DH)
112
+ k = keys.view(B, NH, NS, CS, DH) / math.sqrt(DH)
113
+ v = values.view(B, NH, NS, CS, DH)
114
+
115
+ # forget gates
116
+ log_fgates = torch.nn.functional.logsigmoid(fgate_preact).view(B, NH, NS, CS)
117
+ log_fgates_acc = log_fgates.cumsum(dim=3)
118
+ igate_preact = igate_preact.view(B, NH, NS, CS)
119
+
120
+ log_fgates_rep = log_fgates[:, :, :, :, None].repeat(1, 1, 1, 1, CS)
121
+ log_fg_matrix = torch.tril(log_fgates_rep, diagonal=-1)
122
+ log_prod_fg_matrix = torch.cumsum(log_fg_matrix, dim=3)
123
+
124
+ loggates = (igate_preact + log_prod_fg_matrix[:, :, :, -1]).unsqueeze(-1)
125
+ m_loc, _ = torch.max(loggates, dim=3, keepdim=True)
126
+ loggates = loggates - m_loc
127
+
128
+ kv = k.transpose(-1, -2) @ (v * (loggates).exp())
129
+ ksum = (k * (loggates).exp()).sum(dim=-2)
130
+ C = torch.zeros((B, NH, NS + 1, DH, DH), device=kv.device, dtype=kv.dtype)
131
+ n = torch.zeros((B, NH, NS + 1, DH, 1), device=kv.device, dtype=kv.dtype)
132
+ if initial_C is not None:
133
+ C[:, :, 0] = initial_C
134
+ if initial_n is not None:
135
+ n[:, :, 0] = initial_n
136
+
137
+ m = torch.zeros((B, NH, NS + 1, 1, 1), device=kv.device, dtype=kv.dtype)
138
+ if initial_m is not None:
139
+ m[:, :, 0] = initial_m
140
+
141
+ for i in range(1, NS + 1):
142
+ m[:, :, i] = torch.maximum(
143
+ log_fgates_acc[:, :, i - 1, -1, None, None] + m[:, :, i - 1],
144
+ m_loc[:, :, i - 1],
145
+ )
146
+ C[:, :, i] = (
147
+ C[:, :, i - 1].clone()
148
+ * (
149
+ log_fgates_acc[:, :, i - 1, -1, None, None]
150
+ + m[:, :, i - 1]
151
+ - m[:, :, i]
152
+ ).exp()
153
+ + kv[:, :, i - 1] * (m_loc[:, :, i - 1] - m[:, :, i]).exp()
154
+ )
155
+ n[:, :, i] = (
156
+ n[:, :, i - 1].clone()
157
+ * (
158
+ log_fgates_acc[:, :, i - 1, None, -1:]
159
+ + m[:, :, i - 1]
160
+ - m[:, :, i]
161
+ ).exp()
162
+ + ksum[:, :, i - 1, :, None] * (m_loc[:, :, i - 1] - m[:, :, i]).exp()
163
+ )
164
+
165
+ log_fg_matrix = log_prod_fg_matrix - torch.triu(
166
+ torch.full([1, 1, 1, CS, CS], float("inf")).to(q), diagonal=1
167
+ )
168
+
169
+ # gate decay matrix D (combination of forget gate and input gate)
170
+ log_D_matrix = log_fg_matrix + igate_preact[:, :, :, :, None].transpose(
171
+ -2, -1
172
+ ) # (B, NH, NS, CS, CS)
173
+ D_max, _ = torch.max(log_D_matrix, dim=-1, keepdim=True)
174
+
175
+ stab = torch.maximum(D_max, m[:, :, :-1, :] + log_fgates_acc[:, :, :, :, None])
176
+ inter_C = (
177
+ q * (m[:, :, :-1, :] + log_fgates_acc[:, :, :, :, None] - stab).exp()
178
+ ) @ C[:, :, :-1]
179
+ inter_n = (
180
+ q * (m[:, :, :-1, :] + log_fgates_acc[:, :, :, :, None] - stab).exp()
181
+ ) @ n[:, :, :-1, :]
182
+
183
+ # D matrix stabilization
184
+ log_D_matrix_stabilized = log_D_matrix - stab # (B, NH, NS, CS, CS)
185
+ D_matrix = torch.exp(log_D_matrix_stabilized) # (B, NH, NS, CS, CS)
186
+
187
+ # combination matrix C
188
+ qk_matrix = q @ k.transpose(-2, -1) # (B, NH, NS, CS, CS)
189
+ E_matrix = qk_matrix * D_matrix # (B, NH, NS, CS, CS)
190
+
191
+ normalizer = torch.maximum(
192
+ (E_matrix.sum(dim=-1, keepdim=True) + inter_n).abs(),
193
+ torch.exp(-stab),
194
+ ) # (B, NH, NS, CS, 1)
195
+
196
+ E_matrix_normalized = E_matrix / (normalizer + eps)
197
+
198
+ # retrieved values
199
+ intra = E_matrix_normalized @ v # (B, NH, S, DH)
200
+ inter = inter_C / (normalizer + eps)
201
+
202
+ if return_last_state:
203
+ return (intra + inter).view((B, NH, S, DH)), (C[:, :, -1], n[:, :, -1], m[:, :, -1])
204
+ else:
205
+ return (intra + inter).view((B, NH, S, DH))
206
+
207
+
208
+ # chunkwise backend adapted to handle inputs which are not cleanly divisible by chunk_size
209
+ def chunkwise_variable(
210
+ queries: torch.Tensor,
211
+ keys: torch.Tensor,
212
+ values: torch.Tensor,
213
+ igate_preact: torch.Tensor,
214
+ fgate_preact: torch.Tensor,
215
+ initial_C: Optional[torch.Tensor] = None,
216
+ initial_n: Optional[torch.Tensor] = None,
217
+ initial_m: Optional[torch.Tensor] = None,
218
+ chunk_size: int = 64,
219
+ return_last_state: bool = False,
220
+ eps: float = 1e-6,
221
+ **kwargs,
222
+ ) -> Union[
223
+ torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
224
+ ]:
225
+ """"
226
+ Wrapper around chunkwise_simple to allow sequences with arbitrary lengths
227
+ """
228
+ tail_size = queries.shape[-2] % chunk_size
229
+ if tail_size == 0 or queries.shape[-2] < chunk_size:
230
+ return chunkwise_simple(
231
+ queries, keys, values, igate_preact, fgate_preact,
232
+ initial_C, initial_n, initial_m, chunk_size if tail_size == 0 else tail_size,
233
+ return_last_state, eps, **kwargs
234
+ )
235
+
236
+ sections = [queries.shape[-2] - tail_size, tail_size]
237
+ head_args, tail_args = zip(*(torch.split(x, sections, dim=-2) for x in [
238
+ queries, keys, values, igate_preact, fgate_preact
239
+ ]))
240
+ head_out, state = chunkwise_simple(
241
+ *head_args, initial_C, initial_n, initial_m,
242
+ chunk_size=chunk_size, return_last_state=True, eps=eps, **kwargs
243
+ )
244
+ tail_out = chunkwise_simple(
245
+ *tail_args, *state, chunk_size=tail_size,
246
+ return_last_state=return_last_state, eps=eps, **kwargs
247
+ )
248
+
249
+ if return_last_state:
250
+ return torch.cat([head_out, tail_out[0]], dim=-2), tail_out[-1]
251
+ else:
252
+ return torch.cat([head_out, tail_out], dim=-2)
253
+
254
+
255
+ def recurrent_step_stabilized_simple(
256
+ c_state: torch.Tensor,
257
+ n_state: torch.Tensor,
258
+ m_state: torch.Tensor,
259
+ q: torch.Tensor,
260
+ k: torch.Tensor,
261
+ v: torch.Tensor,
262
+ igate_preact: torch.Tensor,
263
+ fgate_preact: torch.Tensor,
264
+ eps: float = 1e-6,
265
+ **kwargs,
266
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
267
+ """This is a single step of the mLSTM operation in recurrent form.
268
+
269
+ Args:
270
+ c_state (torch.Tensor): (B, NH, DH, DH)
271
+ n_state (torch.Tensor): (B, NH, DH, 1)
272
+ m_state (torch.Tensor): (B, NH, 1, 1)
273
+ q (torch.Tensor): (B, NH, 1, DH)
274
+ k (torch.Tensor): (B, NH, 1, DH)
275
+ v (torch.Tensor): (B, NH, 1, DH)
276
+ igate_preact (torch.Tensor): (B, NH, 1, 1)
277
+ fgate_preact (torch.Tensor): (B, NH, 1, 1)
278
+
279
+ Returns:
280
+ tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
281
+ (hidden_state [B, NH, DH], (c_state_new [B, NH, DH, DH], n_state_new [B, NH, DH, 1]], m_state_new [B, NH, 1, 1]))
282
+ """
283
+ B, NH, S, DH = q.shape
284
+ # projections
285
+ q, k, v = (
286
+ q.squeeze_(2).unsqueeze(-1),
287
+ k.squeeze_(2).unsqueeze(-1),
288
+ v.squeeze_(2).unsqueeze(-1),
289
+ ) # (B, NH, DH, 1)
290
+
291
+ # gates
292
+ log_fg_act = torch.nn.functional.logsigmoid(fgate_preact) # (B, NH, 1, 1)
293
+
294
+ # update rule
295
+ m_state_new = torch.max(log_fg_act + m_state, igate_preact) # (B, NH, 1, 1)
296
+
297
+ fg_act = torch.exp(log_fg_act + m_state - m_state_new) # (B, NH, 1, 1)
298
+ ig_act = torch.exp(igate_preact - m_state_new) # (B, NH, 1, 1)
299
+
300
+ k_scaled = k / math.sqrt(DH)
301
+
302
+ c_state_new = fg_act * c_state + ig_act * (
303
+ k_scaled @ v.transpose(-1, -2)
304
+ ) # (B, NH, DH, DH)
305
+ n_state_new = fg_act * n_state + ig_act * k_scaled # (B, NH, DH, 1)
306
+
307
+ h_num = q.transpose(-1, -2) @ c_state_new # (B, NH, 1, DH)
308
+
309
+ qn_dotproduct = q.transpose(-1, -2) @ n_state_new # (B, NH, 1, 1)
310
+ max_val = torch.exp(-m_state_new) # (B, NH, 1, 1)
311
+ h_denom = torch.maximum(qn_dotproduct.abs(), max_val) + eps
312
+ h = h_num / h_denom # (B, NH, 1, DH) / (B, NH, 1, 1) = (B, NH, 1, DH)
313
+
314
+ return h, (c_state_new, n_state_new, m_state_new)
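A rough consistency sketch for the backends above (not part of the commit; shapes and values are random, and the parallel and chunkwise forms are expected to agree up to numerical error):

import torch
from protxlstm.xlstm.blocks.mlstm.backends import parallel_stabilized_simple, chunkwise_simple

B, NH, S, DH = 1, 2, 128, 16
q, k, v = (torch.randn(B, NH, S, DH) for _ in range(3))
i_pre, f_pre = torch.randn(B, NH, S, 1), torch.randn(B, NH, S, 1)

h_parallel = parallel_stabilized_simple(q, k, v, i_pre, f_pre)       # (B, NH, S, DH), full attention-like form
h_chunked = chunkwise_simple(q, k, v, i_pre, f_pre, chunk_size=64)   # same shape, computed chunk by chunk
print(h_parallel.shape, torch.allclose(h_parallel, h_chunked, atol=1e-4))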
protxlstm/xlstm/blocks/mlstm/block.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright (c) NXAI GmbH and its affiliates 2024
2
+ # Maximilian Beck
3
+
4
+ # Modified by Pieter-Jan Hoedt, Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
5
+ # - Remove sLSTM
6
+
7
+
8
+ from dataclasses import dataclass, field
9
+
10
+ from ..xlstm_block import xLSTMBlock, xLSTMBlockConfig
11
+ from .layer import mLSTMLayerConfig
12
+
13
+
14
+ @dataclass
15
+ class mLSTMBlockConfig:
16
+ mlstm: mLSTMLayerConfig = field(default_factory=mLSTMLayerConfig)
17
+
18
+ def __post_init__(self):
19
+ self.mlstm.__post_init__()
20
+
21
+
22
+ class mLSTMBlock(xLSTMBlock):
23
+
24
+ config_class = mLSTMBlockConfig
25
+
26
+ def __init__(self, config: mLSTMBlockConfig) -> None:
27
+ super().__init__(config=xLSTMBlockConfig(mlstm=config.mlstm, feedforward=None))
protxlstm/xlstm/blocks/mlstm/cell.py ADDED
@@ -0,0 +1,212 @@
1
+ # Copyright (c) NXAI GmbH and its affiliates 2024
2
+ # Maximilian Beck
3
+
4
+ # Modified by Pieter-Jan Hoedt, Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
5
+ # - Add references to chunkwise backends
6
+ # - Modify forward to take and return state
7
+
8
+
9
+ from dataclasses import dataclass
10
+
11
+ import torch
12
+ from torch import nn
13
+ from functools import partial
14
+
15
+ from ...components.init import bias_linspace_init_
16
+ from ...components.ln import MultiHeadLayerNorm
17
+ from .backends import parallel_stabilized_simple, chunkwise_simple, chunkwise_variable, recurrent_step_stabilized_simple
18
+
19
+
20
+ @dataclass
21
+ class mLSTMCellConfig:
22
+ context_length: int = -1
23
+ embedding_dim: int = -1
24
+ num_heads: int = -1
25
+ backend: str = "parallel" # "chunkwise"
26
+ chunk_size: int = 64
27
+ return_last_state: bool = False
28
+
29
+
30
+ class mLSTMCell(nn.Module):
31
+ config_class = mLSTMCellConfig
32
+
33
+ def __init__(self, config: mLSTMCellConfig):
34
+ super().__init__()
35
+ self.config = config
36
+
37
+ if self.config.return_last_state:
38
+ assert config.backend != "parallel", "Parallel backend cannot return state - set return_last_state to False or use a chunkwise backend."
39
+
40
+ if config.backend == "parallel":
41
+ self.backend_fn = parallel_stabilized_simple
42
+ elif config.backend == "chunkwise":
43
+ chunkwise_backend = partial(chunkwise_simple, chunk_size=config.chunk_size, return_last_state=config.return_last_state)
44
+ self.backend_fn = chunkwise_backend
45
+ elif config.backend == "chunkwise_variable":
46
+ chunkwise_backend = partial(chunkwise_variable, chunk_size=config.chunk_size, return_last_state=config.return_last_state)
47
+ self.backend_fn = chunkwise_backend
48
+ else:
49
+ raise ValueError(f"Unknown mLSTM backend: {config.backend}")
50
+ self.backend_fn_step = recurrent_step_stabilized_simple
51
+
52
+ self.igate = nn.Linear(3 * config.embedding_dim, config.num_heads)
53
+ self.fgate = nn.Linear(3 * config.embedding_dim, config.num_heads)
54
+
55
+ self.outnorm = MultiHeadLayerNorm(ndim=config.embedding_dim, weight=True, bias=False)
56
+
57
+ if config.backend == "parallel":
58
+ self.register_buffer(
59
+ "causal_mask",
60
+ torch.tril(torch.ones(config.context_length, config.context_length, dtype=torch.bool)),
61
+ persistent=False,
62
+ )
63
+ else:
64
+ self.causal_mask = None
65
+
66
+ self.reset_parameters()
67
+
68
+ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, state = None, **kwargs) -> torch.Tensor:
69
+ B, S, _ = q.shape # (B, S, H)
70
+
71
+ if_gate_input = torch.cat([q, k, v], dim=-1)
72
+ q = q.view(B, S, self.config.num_heads, -1) # (B, S, NH, DH)
73
+ k = k.view(B, S, self.config.num_heads, -1) # (B, S, NH, DH)
74
+ v = v.view(B, S, self.config.num_heads, -1) # (B, S, NH, DH)
75
+
76
+ q = q.transpose(1, 2) # (B, NH, S, DH)
77
+ k = k.transpose(1, 2) # (B, NH, S, DH)
78
+ v = v.transpose(1, 2) # (B, NH, S, DH)
79
+
80
+ # compute input and forget gate pre-activations
81
+ igate_preact = self.igate(if_gate_input) # (B, S, NH)
82
+ igate_preact = igate_preact.transpose(-1, -2).unsqueeze(-1) # (B, NH, S, 1)
83
+ fgate_preact = self.fgate(if_gate_input) # (B, S, NH)
84
+ fgate_preact = fgate_preact.transpose(-1, -2).unsqueeze(-1) # (B, NH, S, 1)
85
+
86
+ if state is not None and self.config.backend in ["chunkwise", "chunkwise_variable"]:
87
+
88
+ initial_C, initial_n, initial_m = state
89
+
90
+ if self.config.return_last_state:
91
+
92
+ h_state, mlstm_state = self.backend_fn(
93
+ queries=q,
94
+ keys=k,
95
+ values=v,
96
+ igate_preact=igate_preact,
97
+ fgate_preact=fgate_preact,
98
+ initial_C=initial_C,
99
+ initial_n=initial_n,
100
+ initial_m=initial_m,
101
+ lower_triangular_matrix=self.causal_mask,
102
+ )
103
+
104
+ else:
105
+ h_state = self.backend_fn(
106
+ queries=q,
107
+ keys=k,
108
+ values=v,
109
+ igate_preact=igate_preact,
110
+ fgate_preact=fgate_preact,
111
+ initial_C=initial_C,
112
+ initial_n=initial_n,
113
+ initial_m=initial_m,
114
+ lower_triangular_matrix=self.causal_mask,
115
+ ) # (B, NH, S, DH)
116
+
117
+ else:
118
+ if self.config.return_last_state:
119
+ h_state, mlstm_state = self.backend_fn(
120
+ queries=q,
121
+ keys=k,
122
+ values=v,
123
+ igate_preact=igate_preact,
124
+ fgate_preact=fgate_preact,
125
+ lower_triangular_matrix=self.causal_mask,
126
+ )
127
+
128
+ else:
129
+ h_state = self.backend_fn(
130
+ queries=q,
131
+ keys=k,
132
+ values=v,
133
+ igate_preact=igate_preact,
134
+ fgate_preact=fgate_preact,
135
+ lower_triangular_matrix=self.causal_mask,
136
+ ) # (B, NH, S, DH)
137
+
138
+ h_state_norm = self.outnorm(h_state) # (B, NH, S, DH)
139
+ h_state_norm = h_state_norm.transpose(1, 2).reshape(B, S, -1) # (B, NH, S, DH) -> (B, S, NH, DH) -> (B, S, H)
140
+
141
+ if self.config.return_last_state:
142
+ return h_state_norm, mlstm_state
143
+ else:
144
+ return h_state_norm
145
+
146
+ def step(
147
+ self,
148
+ q: torch.Tensor,
149
+ k: torch.Tensor,
150
+ v: torch.Tensor,
151
+ mlstm_state: tuple[torch.Tensor, torch.Tensor, torch.Tensor] = None,
152
+ **kwargs,
153
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
154
+ B, S, _ = q.shape # (B, S, H)
155
+ assert S == 1, f"mLSTMCell.step only supports sequence length S=1, but got S={S}."
156
+
157
+ if_gate_input = torch.cat([q, k, v], dim=-1)
158
+ q = q.view(B, S, self.config.num_heads, -1) # (B, S, NH, DH)
159
+ k = k.view(B, S, self.config.num_heads, -1) # (B, S, NH, DH)
160
+ v = v.view(B, S, self.config.num_heads, -1) # (B, S, NH, DH)
161
+
162
+ _, _, NH, DH = q.shape
163
+
164
+ q = q.transpose(1, 2) # (B, NH, S, DH)
165
+ k = k.transpose(1, 2) # (B, NH, S, DH)
166
+ v = v.transpose(1, 2) # (B, NH, S, DH)
167
+
168
+ # compute input and forget gate pre-activations
169
+ igate_preact = self.igate(if_gate_input) # (B, S, NH)
170
+ igate_preact = igate_preact.transpose(-1, -2).unsqueeze(-1) # (B, NH, S, 1)
171
+ fgate_preact = self.fgate(if_gate_input) # (B, S, NH)
172
+ fgate_preact = fgate_preact.transpose(-1, -2).unsqueeze(-1) # (B, NH, S, 1)
173
+
174
+ if mlstm_state is None:
175
+ c_state = torch.zeros(size=(B, NH, DH, DH), device=q.device, dtype=q.dtype)
176
+ n_state = torch.zeros(size=(B, NH, DH, 1), device=q.device, dtype=q.dtype)
177
+ m_state = torch.zeros(size=(B, NH, 1, 1), device=q.device, dtype=q.dtype)
178
+ else:
179
+ c_state, n_state, m_state = mlstm_state
180
+ c_state = c_state.to(device=q.device, dtype=q.dtype)
181
+ n_state = n_state.to(device=q.device, dtype=q.dtype)
182
+ m_state = m_state.to(device=q.device, dtype=q.dtype)
183
+
184
+ assert c_state.shape == (B, NH, DH, DH), f"Expected c_state shape {(B, NH, DH, DH)}, but got {c_state.shape}."
185
+ assert n_state.shape == (B, NH, DH, 1), f"Expected n_state shape {(B, NH, DH, 1)}, but got {n_state.shape}."
186
+ assert m_state.shape == (B, NH, 1, 1), f"Expected m_state shape {(B, NH, 1, 1)}, but got {m_state.shape}."
187
+
188
+ h_state, mlstm_state = self.backend_fn_step(
189
+ c_state=c_state,
190
+ n_state=n_state,
191
+ m_state=m_state,
192
+ q=q,
193
+ k=k,
194
+ v=v,
195
+ igate_preact=igate_preact,
196
+ fgate_preact=fgate_preact,
197
+ ) # (B, NH, 1, DH), ((B, NH, DH, DH), (B, NH, DH, 1), (B, NH, 1, 1))
198
+
199
+ h_state_norm = self.outnorm(h_state) # (B, NH, S, DH)
200
+ h_state_norm = h_state_norm.transpose(1, 2).reshape(B, S, -1) # (B, NH, S, DH) -> (B, S, NH, DH) -> (B, S, H)
201
+
202
+ return h_state_norm, mlstm_state
203
+
204
+
205
+ def reset_parameters(self):
206
+ self.outnorm.reset_parameters()
207
+ # forget gate initialization
208
+ torch.nn.init.zeros_(self.fgate.weight)
209
+ bias_linspace_init_(self.fgate.bias, start=3.0, end=6.0)
210
+ # input gate initialization
211
+ torch.nn.init.zeros_(self.igate.weight)
212
+ torch.nn.init.normal_(self.igate.bias, mean=0.0, std=0.1)
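
The `step` method is what generation uses: it consumes one token at a time and threads the `(C, n, m)` state explicitly. A minimal sketch (illustrative sizes, assumed import path):

```python
# Illustrative sizes; assumes the package layout added in this commit.
import torch
from protxlstm.xlstm.blocks.mlstm.cell import mLSTMCell, mLSTMCellConfig

cell = mLSTMCell(mLSTMCellConfig(context_length=16, embedding_dim=64, num_heads=4))

state = None
for _ in range(5):                       # feed 5 tokens one by one
    q, k, v = (torch.randn(2, 1, 64) for _ in range(3))
    h, state = cell.step(q, k, v, mlstm_state=state)

c_state, n_state, m_state = state
print(h.shape, c_state.shape)            # torch.Size([2, 1, 64]) torch.Size([2, 4, 16, 16])
```
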
protxlstm/xlstm/blocks/mlstm/layer.py ADDED
@@ -0,0 +1,217 @@
1
+ # Copyright (c) NXAI GmbH and its affiliates 2024
2
+ # Maximilian Beck
3
+
4
+ # Modified by Pieter-Jan Hoedt, Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
5
+ # - Modify forward to take and return state
6
+
7
+
8
+ from dataclasses import dataclass
9
+
10
+ import torch
11
+ from torch import nn
12
+
13
+ from ...components.conv import CausalConv1d, CausalConv1dConfig
14
+ from ...components.init import small_init_init_, wang_init_
15
+ from ...components.linear_headwise import (
16
+ LinearHeadwiseExpand,
17
+ LinearHeadwiseExpandConfig,
18
+ )
19
+ from ...utils import UpProjConfigMixin
20
+ from ...components.rotary_position import apply_rotary_emb
21
+ from .cell import mLSTMCell, mLSTMCellConfig
22
+
23
+
24
+ @dataclass
25
+ class mLSTMLayerConfig(UpProjConfigMixin):
26
+ conv1d_kernel_size: int = 4
27
+ qkv_proj_blocksize: int = 4
28
+ num_heads: int = 4
29
+ proj_factor: float = 2.0
30
+
31
+ # will be set by the top-level config
32
+ embedding_dim: int = -1
33
+ bias: bool = False
34
+ dropout: float = 0.0
35
+ context_length: int = -1
36
+ backend: str = "parallel"  # alternatives: "chunkwise", "chunkwise_variable"
37
+ chunk_size: int = 64
38
+ return_last_state: bool = False
39
+
40
+ _num_blocks: int = 1
41
+ _inner_embedding_dim: int = None
42
+
43
+ def __post_init__(self):
44
+ self._set_proj_up_dim(embedding_dim=self.embedding_dim)
45
+ self._inner_embedding_dim = self._proj_up_dim
46
+
47
+
48
+ class mLSTMLayer(nn.Module):
49
+ config_class = mLSTMLayerConfig
50
+
51
+ def __init__(self, config: mLSTMLayerConfig):
52
+ super().__init__()
53
+ self.config = config
54
+
55
+ self.proj_up = nn.Linear(
56
+ in_features=self.config.embedding_dim,
57
+ out_features=2 * self.config._inner_embedding_dim,
58
+ bias=self.config.bias,
59
+ )
60
+
61
+ num_proj_heads = round(self.config._inner_embedding_dim // self.config.qkv_proj_blocksize)
62
+ self.q_proj = LinearHeadwiseExpand(
63
+ config=LinearHeadwiseExpandConfig(
64
+ in_features=self.config._inner_embedding_dim,
65
+ num_heads=num_proj_heads,
66
+ bias=self.config.bias,
67
+ )
68
+ )
69
+ self.k_proj = LinearHeadwiseExpand(
70
+ config=LinearHeadwiseExpandConfig(
71
+ in_features=self.config._inner_embedding_dim,
72
+ num_heads=num_proj_heads,
73
+ bias=self.config.bias,
74
+ )
75
+ )
76
+ self.v_proj = LinearHeadwiseExpand(
77
+ config=LinearHeadwiseExpandConfig(
78
+ in_features=self.config._inner_embedding_dim,
79
+ num_heads=num_proj_heads,
80
+ bias=self.config.bias,
81
+ )
82
+ )
83
+
84
+ self.conv1d = CausalConv1d(
85
+ config=CausalConv1dConfig(
86
+ feature_dim=self.config._inner_embedding_dim,
87
+ kernel_size=self.config.conv1d_kernel_size,
88
+ )
89
+ )
90
+ self.conv_act_fn = nn.SiLU()
91
+ self.mlstm_cell = mLSTMCell(
92
+ config=mLSTMCellConfig(
93
+ context_length=self.config.context_length,
94
+ embedding_dim=self.config._inner_embedding_dim,
95
+ num_heads=self.config.num_heads,
96
+ backend=self.config.backend,
97
+ chunk_size=self.config.chunk_size,
98
+ return_last_state = self.config.return_last_state
99
+ )
100
+ )
101
+ self.ogate_act_fn = nn.SiLU()
102
+
103
+ self.learnable_skip = nn.Parameter(torch.ones(self.config._inner_embedding_dim, requires_grad=True))
104
+
105
+ self.proj_down = nn.Linear(
106
+ in_features=self.config._inner_embedding_dim,
107
+ out_features=self.config.embedding_dim,
108
+ bias=self.config.bias,
109
+ )
110
+ self.dropout = nn.Dropout(self.config.dropout)
111
+ self.reset_parameters()
112
+
113
+ def forward(self, x: torch.Tensor, freqs_cos=None, freqs_sin=None, state=None, **kwargs) -> torch.Tensor:
114
+ B, S, _ = x.shape
115
+
116
+ # up-projection
117
+ x_inner = self.proj_up(x)
118
+ x_mlstm, z = torch.split(x_inner, split_size_or_sections=self.config._inner_embedding_dim, dim=-1)
119
+
120
+ # mlstm branch
121
+ if state is not None:
122
+ mlstm_state = state["mlstm_state"]
123
+ conv_state = state["conv_state"][0]
124
+ else:
125
+ mlstm_state, conv_state = None, None
126
+
127
+ if self.config.return_last_state:
128
+ x_mlstm_conv, conv_state = self.conv1d(x_mlstm, conv_state = conv_state, return_last_state = True)
129
+ else:
130
+ x_mlstm_conv = self.conv1d(x_mlstm, conv_state = conv_state)
131
+
132
+ x_mlstm_conv_act = self.conv_act_fn(x_mlstm_conv)
133
+
134
+ q = self.q_proj(x_mlstm_conv_act)
135
+ k = self.k_proj(x_mlstm_conv_act)
136
+ v = self.v_proj(x_mlstm)
137
+
138
+ if freqs_cos is not None and freqs_sin is not None:
139
+ q, k = apply_rotary_emb(q, k, freqs_cos, freqs_sin)
140
+
141
+ if self.config.return_last_state:
142
+ h_tilde_state, mlstm_state = self.mlstm_cell(q=q, k=k, v=v, state=mlstm_state, **kwargs)
143
+ else:
144
+ h_tilde_state = self.mlstm_cell(q=q, k=k, v=v, state=mlstm_state, **kwargs)
145
+
146
+ h_tilde_state_skip = h_tilde_state + (self.learnable_skip * x_mlstm_conv_act)
147
+
148
+ # output / z branch
149
+ h_state = h_tilde_state_skip * self.ogate_act_fn(z)
150
+
151
+ # down-projection
152
+ y = self.dropout(self.proj_down(h_state))
153
+
154
+ if self.config.return_last_state:
155
+ return y, {"mlstm_state": mlstm_state, "conv_state": (conv_state,)}
156
+ else:
157
+ return y
158
+
159
+ def step(
160
+ self,
161
+ x: torch.Tensor,
162
+ freqs_cos=None,
163
+ freqs_sin=None,
164
+ mlstm_state: tuple[torch.Tensor, torch.Tensor, torch.Tensor] = None,
165
+ conv_state: tuple[torch.Tensor] = None,
166
+ ) -> tuple[torch.Tensor, dict[str, tuple[torch.Tensor, ...]]]:
167
+ B, S, _ = x.shape
168
+
169
+ # up-projection
170
+ x_inner = self.proj_up(x)
171
+ x_mlstm, z = torch.split(x_inner, split_size_or_sections=self.config._inner_embedding_dim, dim=-1)
172
+
173
+ # mlstm branch
174
+ x_mlstm_conv, conv_state = self.conv1d.step(x_mlstm, conv_state=conv_state)
175
+ x_mlstm_conv_act = self.conv_act_fn(x_mlstm_conv)
176
+
177
+ q = self.q_proj(x_mlstm_conv_act)
178
+ k = self.k_proj(x_mlstm_conv_act)
179
+ v = self.v_proj(x_mlstm)
180
+
181
+ if freqs_cos is not None and freqs_sin is not None:
182
+ q, k = apply_rotary_emb(q, k, freqs_cos, freqs_sin)
183
+
184
+ h_tilde_state, mlstm_state = self.mlstm_cell.step(q=q, k=k, v=v, mlstm_state=mlstm_state)
185
+
186
+ h_tilde_state_skip = h_tilde_state + (self.learnable_skip * x_mlstm_conv_act)
187
+
188
+ # output / z branch
189
+ h_state = h_tilde_state_skip * self.ogate_act_fn(z)
190
+
191
+ # down-projection
192
+ y = self.dropout(self.proj_down(h_state))
193
+ return y, {"mlstm_state": mlstm_state, "conv_state": conv_state}
194
+
195
+ def reset_parameters(self):
196
+ # init inproj
197
+ small_init_init_(self.proj_up.weight, dim=self.config.embedding_dim)
198
+ if self.proj_up.bias is not None:
199
+ nn.init.zeros_(self.proj_up.bias)
200
+ # init outproj
201
+ wang_init_(self.proj_down.weight, dim=self.config.embedding_dim, num_blocks=self.config._num_blocks)
202
+ if self.proj_down.bias is not None:
203
+ nn.init.zeros_(self.proj_down.bias)
204
+
205
+ nn.init.ones_(self.learnable_skip)
206
+
207
+ def _init_qkv_proj(qkv_proj: LinearHeadwiseExpand):
208
+ # use the embedding dim instead of the inner embedding dim
209
+ small_init_init_(qkv_proj.weight, dim=self.config.embedding_dim)
210
+ if qkv_proj.bias is not None:
211
+ nn.init.zeros_(qkv_proj.bias)
212
+
213
+ _init_qkv_proj(self.q_proj)
214
+ _init_qkv_proj(self.k_proj)
215
+ _init_qkv_proj(self.v_proj)
216
+
217
+ self.mlstm_cell.reset_parameters()
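
For autoregressive decoding the layer also exposes a `step` method that threads both the causal-convolution buffer and the mLSTM cell state. A small sketch (illustrative sizes, assumed import path):

```python
# Illustrative sizes; assumes the package layout added in this commit.
import torch
from protxlstm.xlstm.blocks.mlstm.layer import mLSTMLayer, mLSTMLayerConfig

layer = mLSTMLayer(mLSTMLayerConfig(embedding_dim=64, num_heads=4, context_length=32))

state = {"mlstm_state": None, "conv_state": None}
for _ in range(4):
    x_t = torch.randn(2, 1, 64)          # one token per call
    y_t, state = layer.step(x_t, **state)

print(y_t.shape)                         # torch.Size([2, 1, 64])
print(sorted(state.keys()))              # ['conv_state', 'mlstm_state']
```
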
protxlstm/xlstm/blocks/xlstm_block.py ADDED
@@ -0,0 +1,111 @@
1
+ # Copyright (c) NXAI GmbH and its affiliates 2024
2
+ # Maximilian Beck
3
+
4
+ # Modified by Pieter-Jan Hoedt, Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
5
+ # - Remove sLSTM
6
+ # - Modify forward to take and return state
7
+
8
+
9
+ from dataclasses import dataclass
10
+ from typing import Optional
11
+
12
+ import torch
13
+ from torch import nn
14
+
15
+ from ..components.feedforward import FeedForwardConfig, create_feedforward
16
+ from ..components.ln import LayerNorm
17
+ from .mlstm.layer import mLSTMLayer, mLSTMLayerConfig
18
+
19
+ """An xLSTM block can be either an sLSTM Block or an mLSTM Block.
20
+ In this repository only mLSTM is implemented.
21
+
22
+ It contains the pre-LayerNorms and the skip connections.
23
+ """
24
+
25
+
26
+ @dataclass
27
+ class xLSTMBlockConfig:
28
+ mlstm: Optional[mLSTMLayerConfig] = None
29
+
30
+ feedforward: Optional[FeedForwardConfig] = None
31
+
32
+ _num_blocks: int = 1
33
+ _block_idx: int = 0
34
+
35
+ def __post_init__(self):
36
+ assert (
37
+ self.mlstm is not None
38
+ ), "mlstm config must be provided"
39
+
40
+ embedding_dim = (
41
+ self.mlstm.embedding_dim
42
+ )
43
+
44
+ self.mlstm._num_blocks = self._num_blocks
45
+ self.mlstm._block_idx = self._block_idx
46
+
47
+ if self.feedforward:
48
+ self.feedforward.embedding_dim = embedding_dim
49
+ self.feedforward._num_blocks = self._num_blocks
50
+ self.feedforward.__post_init__()
51
+
52
+
53
+ class xLSTMBlock(nn.Module):
54
+
55
+ config_class = xLSTMBlockConfig
56
+
57
+ def __init__(self, config: xLSTMBlockConfig) -> None:
58
+ super().__init__()
59
+ self.config = config
60
+ embedding_dim = (
61
+ self.config.mlstm.embedding_dim
62
+ )
63
+
64
+ self.xlstm_norm = LayerNorm(ndim=embedding_dim, weight=True, bias=False)
65
+
66
+ if self.config.mlstm is not None:
67
+ self.xlstm = mLSTMLayer(config=self.config.mlstm)
68
+ else:
69
+ raise ValueError("mlstm must be provided")
70
+
71
+ if self.config.feedforward is not None:
72
+ self.ffn_norm = LayerNorm(
73
+ ndim=self.config.feedforward.embedding_dim, weight=True, bias=False
74
+ )
75
+ self.ffn = create_feedforward(config=self.config.feedforward)
76
+ else:
77
+ self.ffn_norm = None
78
+ self.ffn = None
79
+
80
+ self.reset_parameters()
81
+
82
+ def forward(self, x: torch.Tensor, state=None, **kwargs) -> torch.Tensor:
83
+ if self.config.mlstm.return_last_state:
84
+ x_xlstm, xlstm_state = self.xlstm(self.xlstm_norm(x), state=state, **kwargs)
85
+ x = x + x_xlstm
86
+ else:
87
+ x = x + self.xlstm(self.xlstm_norm(x), state=state, **kwargs)
88
+
89
+ if self.ffn is not None:
90
+ x = x + self.ffn(self.ffn_norm(x), **kwargs)
91
+
92
+ if self.config.mlstm.return_last_state:
93
+ return x, xlstm_state
94
+ else:
95
+ return x
96
+
97
+ def step(self, x: torch.Tensor, **kwargs) -> tuple[torch.Tensor, dict[str, tuple[torch.Tensor, ...]]]:
98
+ x_xlstm, xlstm_state = self.xlstm.step(self.xlstm_norm(x), **kwargs)
99
+ x = x + x_xlstm
100
+ if self.ffn is not None:
101
+ x = x + self.ffn(self.ffn_norm(x), **kwargs)
102
+ return x, xlstm_state
103
+
104
+ def reset_parameters(self) -> None:
105
+
106
+ self.xlstm.reset_parameters()
107
+ self.xlstm_norm.reset_parameters()
108
+
109
+ if self.ffn is not None:
110
+ self.ffn.reset_parameters()
111
+ self.ffn_norm.reset_parameters()
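
Wired together, a full block applies pre-LayerNorm, the mLSTM layer, and a residual connection, optionally followed by pre-LayerNorm, the gated feed-forward, and another residual connection. A sketch with both sub-layers (illustrative sizes, assumed import paths):

```python
# Illustrative sizes; assumes the package layout added in this commit.
import torch
from protxlstm.xlstm.blocks.xlstm_block import xLSTMBlock, xLSTMBlockConfig
from protxlstm.xlstm.blocks.mlstm.layer import mLSTMLayerConfig
from protxlstm.xlstm.components.feedforward import FeedForwardConfig

cfg = xLSTMBlockConfig(
    mlstm=mLSTMLayerConfig(embedding_dim=64, num_heads=4, context_length=32),
    feedforward=FeedForwardConfig(proj_factor=1.3, act_fn="gelu"),
)
block = xLSTMBlock(cfg)

y = block(torch.randn(2, 32, 64))
print(y.shape)                           # torch.Size([2, 32, 64])
```
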
protxlstm/xlstm/components/__init__.py ADDED
File without changes
protxlstm/xlstm/components/conv.py ADDED
@@ -0,0 +1,163 @@
1
+ # Copyright (c) NXAI GmbH and its affiliates 2024
2
+ # Maximilian Beck, Korbinian Pöppel
3
+
4
+ # Modified by Pieter-Jan Hoedt, Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
5
+ # - Modify forward to take and return state
6
+
7
+
8
+ from dataclasses import dataclass, field
9
+ from typing import Optional
10
+
11
+ import torch
12
+
13
+ # from einops import rearrange
14
+ from torch import nn
15
+
16
+
17
+ @dataclass
18
+ class CausalConv1dConfig:
19
+ feature_dim: int = None # F
20
+ kernel_size: int = 4
21
+ causal_conv_bias: bool = True
22
+ channel_mixing: bool = False
23
+ conv1d_kwargs: dict = field(default_factory=dict)
24
+
25
+ def __post_init__(self):
26
+ assert self.kernel_size >= 0, "kernel_size must be >= 0"
27
+
28
+
29
+ def conv1d_step(
30
+ x: torch.Tensor,
31
+ conv_state: torch.Tensor,
32
+ conv1d_weight: torch.Tensor,
33
+ conv1d_bias: torch.Tensor = None,
34
+ ) -> tuple[torch.Tensor, torch.Tensor]:
35
+ """
36
+ B: batch size
37
+ S: sequence length
38
+ D: feature dimension
39
+ KS: kernel size
40
+ Args:
41
+ x (torch.Tensor): (B, S, D)
42
+ conv_state (torch.Tensor): (B, KS, D)
43
+ conv1d_weight (torch.Tensor): (KS, D)
44
+ """
45
+ assert (
46
+ x.shape[0] == conv_state.shape[0]
47
+ ), f"x has batch size {x.shape[0]} but conv_state has batch size {conv_state.shape[0]}"
48
+ assert (
49
+ x.shape[2] == conv_state.shape[2]
50
+ ), f"x has feature dimension {x.shape[2]} but conv_state has feature dimension {conv_state.shape[2]}"
51
+ assert x.shape[1] == 1, f"x has sequence length {x.shape[1]} but it should be 1"
52
+ conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=1))
53
+ conv_state[:, -1:, :] = x
54
+ y = torch.sum(conv_state * conv1d_weight, dim=1, keepdim=True)
55
+ if conv1d_bias is not None:
56
+ y += conv1d_bias
57
+ return y, conv_state
58
+
59
+
60
+ class CausalConv1d(nn.Module):
61
+ config_class = CausalConv1dConfig
62
+ """
63
+ Implements causal depthwise convolution of a time series tensor.
64
+ Input: Tensor of shape (B,T,F), i.e. (batch, time, feature)
65
+ Output: Tensor of shape (B,T,F)
66
+
67
+ Args:
68
+ feature_dim: number of features in the input tensor
69
+ kernel_size: size of the kernel for the depthwise convolution
70
+ causal_conv_bias: whether to use bias in the depthwise convolution
71
+ channel_mixing: whether to use channel mixing (i.e. groups=1) or not (i.e. groups=feature_dim)
72
+ If True, it mixes the convolved features across channels.
73
+ If False, all the features are convolved independently.
74
+ """
75
+
76
+ def __init__(self, config: CausalConv1dConfig):
77
+ super().__init__()
78
+ self.config = config
79
+ self.groups = self.config.feature_dim
80
+ if self.config.channel_mixing:
81
+ self.groups = 1
82
+ if self.config.kernel_size == 0:
83
+ self.conv = None # Noop
84
+ else:
85
+ self.pad = (
86
+ self.config.kernel_size - 1
87
+ ) # padding of this size assures temporal causality.
88
+ self.conv = nn.Conv1d(
89
+ in_channels=self.config.feature_dim,
90
+ out_channels=self.config.feature_dim,
91
+ kernel_size=self.config.kernel_size,
92
+ padding=self.pad,
93
+ groups=self.groups,
94
+ bias=self.config.causal_conv_bias,
95
+ **self.config.conv1d_kwargs,
96
+ )
97
+ # B, C, L
98
+ self.reset_parameters()
99
+
100
+ def reset_parameters(self, **kwargs):
101
+ if self.conv is not None: self.conv.reset_parameters()  # conv is None when kernel_size == 0
102
+
103
+ def _create_weight_decay_optim_groups(
104
+ self,
105
+ ) -> tuple[set[nn.Parameter], set[nn.Parameter]]:
106
+ if self.config.kernel_size == 0:
107
+ return (), ()
108
+ else:
109
+ weight_decay = (self.conv.weight,)
110
+ no_weight_decay = ()
111
+ if self.config.causal_conv_bias:
112
+ no_weight_decay += (self.conv.bias,)
113
+ return weight_decay, no_weight_decay
114
+
115
+ def forward(
116
+ self,
117
+ x: torch.Tensor,
118
+ conv_state: Optional[torch.Tensor] = None,
119
+ return_last_state: bool = False,
120
+ ) -> torch.Tensor:
121
+ if conv_state is not None:
122
+ conv_state = conv_state[:,-self.pad:]
123
+ x = torch.cat([conv_state, x], dim=1)
124
+
125
+ if self.config.kernel_size == 0:
126
+ return x
127
+ y = x.transpose(2, 1) # (B,F,T) tensor - now in the right shape for conv layer.
128
+ y = self.conv(y) # (B,F,T+pad) tensor
129
+ if conv_state is not None:
130
+ y = y[:, :, conv_state.shape[1] :]
131
+
132
+ if return_last_state:
133
+ return y[:, :, : -self.pad].transpose(2, 1), x[:, -self.config.kernel_size:] #[:, -self.pad :]
134
+ else:
135
+ return y[:, :, : -self.pad].transpose(2, 1)
136
+
137
+ def step(
138
+ self,
139
+ x: torch.Tensor,
140
+ conv_state: tuple[torch.Tensor] = None,
141
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor]]:
142
+
143
+ if self.config.kernel_size == 0:
144
+ return x, conv_state
145
+
146
+ B, S, D = x.shape
147
+
148
+ if conv_state is None:
149
+ conv_state = (
150
+ torch.zeros(
151
+ size=(B, self.config.kernel_size, D),
152
+ device=self.conv.weight.device,
153
+ dtype=self.conv.weight.dtype,
154
+ ),
155
+ )
156
+
157
+ y, conv_state = conv1d_step(
158
+ x,
159
+ conv_state[0],
160
+ self.conv.weight[:, 0, :].transpose(0, 1), # rearrange(, "D 1 KS -> KS D")
161
+ conv1d_bias=self.conv.bias if self.config.causal_conv_bias else None,
162
+ )
163
+ return y, (conv_state,)
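
The `pad = kernel_size - 1` padding, together with trimming the trailing outputs, makes the convolution causal: the output at time t depends only on inputs at times up to t. A quick check (illustrative sizes, assumed import path):

```python
# Illustrative sizes; assumes the package layout added in this commit.
import torch
from protxlstm.xlstm.components.conv import CausalConv1d, CausalConv1dConfig

conv = CausalConv1d(CausalConv1dConfig(feature_dim=8, kernel_size=4))

x = torch.randn(1, 10, 8)                # (batch, time, features)
x_perturbed = x.clone()
x_perturbed[:, 7:, :] += 1.0             # change only time steps t >= 7

y = conv(x)
y_perturbed = conv(x_perturbed)
print(torch.allclose(y[:, :7], y_perturbed[:, :7]))   # True: earlier outputs are unaffected
```
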
protxlstm/xlstm/components/feedforward.py ADDED
@@ -0,0 +1,88 @@
1
+ # Copyright (c) NXAI GmbH and its affiliates 2024
2
+ # Maximilian Beck
3
+ from dataclasses import dataclass
4
+ from typing import Callable, Literal
5
+
6
+ import torch
7
+ from torch import nn
8
+
9
+ from ..utils import UpProjConfigMixin
10
+ from .init import small_init_init_, wang_init_
11
+
12
+ _act_fn_registry = {
13
+ "gelu": nn.functional.gelu,
14
+ "relu": nn.functional.relu,
15
+ "relu^2": lambda x: torch.square(nn.functional.relu(x)),
16
+ "sigmoid": nn.functional.sigmoid,
17
+ "swish": nn.functional.silu,
18
+ "selu": nn.functional.selu,
19
+ }
20
+
21
+
22
+ def get_act_fn(act_fn_name: str) -> Callable[[torch.Tensor], torch.Tensor]:
23
+ if act_fn_name in _act_fn_registry:
24
+ return _act_fn_registry[act_fn_name]
25
+ else:
26
+ assert (
27
+ False
28
+ ), f'Unknown activation function name "{act_fn_name}". Available activation functions are: {str(_act_fn_registry.keys())}'
29
+
30
+
31
+ @dataclass
32
+ class FeedForwardConfig(UpProjConfigMixin):
33
+ proj_factor: float = 1.3
34
+ act_fn: str = "gelu"
35
+ embedding_dim: int = -1
36
+ dropout: float = 0.0
37
+ bias: bool = False
38
+ ff_type: Literal["ffn_gated"] = "ffn_gated"
39
+
40
+ _num_blocks: int = 1
41
+
42
+ def __post_init__(self):
43
+ self._set_proj_up_dim(embedding_dim=self.embedding_dim)
44
+ assert self.act_fn in _act_fn_registry, f"Unknown activation function {self.act_fn}"
45
+
46
+
47
+ class GatedFeedForward(nn.Module):
48
+ config_class = FeedForwardConfig
49
+
50
+ def __init__(self, config: FeedForwardConfig):
51
+ super().__init__()
52
+ self.config = config
53
+
54
+ self.proj_up = nn.Linear(
55
+ in_features=self.config.embedding_dim,
56
+ out_features=2 * self.config._proj_up_dim,
57
+ bias=self.config.bias,
58
+ )
59
+ self.proj_down = nn.Linear(
60
+ in_features=self.config._proj_up_dim,
61
+ out_features=self.config.embedding_dim,
62
+ bias=self.config.bias,
63
+ )
64
+
65
+ self.act_fn = get_act_fn(self.config.act_fn)
66
+
67
+ self.dropout = nn.Dropout(self.config.dropout)
68
+ self.reset_parameters()
69
+
70
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
71
+ gate_preact, up_proj = self.proj_up(x).split(self.config._proj_up_dim, dim=-1)
72
+ x = self.dropout(self.proj_down(self.act_fn(gate_preact) * up_proj))
73
+ return x
74
+
75
+ def reset_parameters(self):
76
+ small_init_init_(self.proj_up.weight, dim=self.config.embedding_dim)
77
+ if self.proj_up.bias is not None:
78
+ nn.init.zeros_(self.proj_up.bias)
79
+ wang_init_(self.proj_down.weight, dim=self.config.embedding_dim, num_blocks=self.config._num_blocks)
80
+ if self.proj_down.bias is not None:
81
+ nn.init.zeros_(self.proj_down.bias)
82
+
83
+
84
+ def create_feedforward(config: FeedForwardConfig) -> nn.Module:
85
+ if config.ff_type == "ffn_gated":
86
+ return GatedFeedForward(config)
87
+ else:
88
+ raise ValueError(f"Unknown feedforward type {config.ff_type}")
protxlstm/xlstm/components/init.py ADDED
@@ -0,0 +1,32 @@
1
+ # Copyright (c) NXAI GmbH and its affiliates 2024
2
+ # Maximilian Beck
3
+ import math
4
+
5
+ import torch
6
+
7
+
8
+ def bias_linspace_init_(param: torch.Tensor, start: float = 3.4, end: float = 6.0) -> torch.Tensor:
9
+ """Linearly spaced bias init across dimensions."""
10
+ assert param.dim() == 1, f"param must be 1-dimensional (typically a bias), got {param.dim()}"
11
+ n_dims = param.shape[0]
12
+ init_vals = torch.linspace(start, end, n_dims)
13
+ with torch.no_grad():
14
+ param.copy_(init_vals)
15
+ return param
16
+
17
+
18
+ def small_init_init_(param: torch.Tensor, dim: int) -> torch.Tensor:
19
+ """Fills the input Tensor with values according to the method described in Transformers without Tears: Improving
20
+ the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2019), using a normal distribution.
21
+ Adopted from https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py.
22
+ """
23
+ std = math.sqrt(2 / (5 * dim))
24
+ torch.nn.init.normal_(param, mean=0.0, std=std)
25
+ return param
26
+
27
+
28
+ def wang_init_(param: torch.Tensor, dim: int, num_blocks: int):
29
+ """Adopted from https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py."""
30
+ std = 2 / num_blocks / math.sqrt(dim)
31
+ torch.nn.init.normal_(param, mean=0.0, std=std)
32
+ return param
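
The linspace bias init is used for the forget gate: with biases ramping from 3.0 to 6.0, the sigmoid of the pre-activation starts close to 1, so each head initially retains most of its memory. A small illustration:

```python
# Start/end values mirror the forget-gate init used in mLSTMCell.reset_parameters above.
import torch
from protxlstm.xlstm.components.init import bias_linspace_init_

bias = torch.empty(8)
bias_linspace_init_(bias, start=3.0, end=6.0)
print(bias)                                      # 8 values evenly spaced between 3.0 and 6.0
print(torch.sigmoid(bias).min().item() > 0.95)   # True: forget gates are open at init
```
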
protxlstm/xlstm/components/linear_headwise.py ADDED
@@ -0,0 +1,92 @@
1
+ # Copyright (c) NXAI GmbH and its affiliates 2024
2
+ # Maximilian Beck, Korbinian Pöppel
3
+ from dataclasses import dataclass
4
+
5
+ from math import sqrt
6
+ import torch
7
+
8
+ # from einops import einsum, rearrange
9
+ from torch import nn
10
+
11
+
12
+ @dataclass
13
+ class LinearHeadwiseExpandConfig:
14
+ in_features: int = 0
15
+ # this is the number of heads that the in_features are split into
16
+ # if num_heads=1, this is a normal linear layer
17
+ # if num_heads>1, the in_features are split into num_heads and each head is projected separately
18
+ # if num_heads=in_features, each feature is projected separately
19
+ num_heads: int = -1
20
+ expand_factor_up: float = 1
21
+
22
+ # this is internally computed
23
+ # but can be overwritten if you want to use a different output dimension
24
+ # if > 0 the expand factor is ignored
25
+ _out_features: int = -1
26
+
27
+ bias: bool = True
28
+ trainable_weight: bool = True
29
+ trainable_bias: bool = True
30
+
31
+ def __post_init__(self):
32
+ assert self.num_heads > 0, "num_heads must be set"
33
+ assert self.num_heads <= self.in_features, "num_heads must be <= in_features"
34
+ assert (
35
+ self.in_features % self.num_heads == 0
36
+ ), "in_features must be a multiple of num_heads"
37
+
38
+ if self._out_features < 0:
39
+ self._out_features = round(self.expand_factor_up * self.in_features)
40
+
41
+
42
+ class LinearHeadwiseExpand(nn.Module):
43
+ """This is a structured projection layer that projects the input to a higher dimension.
44
+ It only allows integer up-projection factors, i.e. the output dimension is a multiple of the input dimension.
45
+ """
46
+
47
+ config_class = LinearHeadwiseExpandConfig
48
+
49
+ def __init__(self, config: LinearHeadwiseExpandConfig):
50
+ super().__init__()
51
+ self.config = config
52
+ in_features = self.config.in_features
53
+ num_heads = self.config.num_heads
54
+ out_features_per_head = config._out_features // num_heads
55
+ self.weight = nn.Parameter(
56
+ torch.empty(num_heads, out_features_per_head, in_features // num_heads),
57
+ requires_grad=config.trainable_weight,
58
+ )
59
+ if config.bias:
60
+ self.bias = nn.Parameter(
61
+ torch.empty(config._out_features), requires_grad=config.trainable_bias
62
+ )
63
+ else:
64
+ self.bias = None
65
+ self.reset_parameters()
66
+
67
+ def reset_parameters(self, **kwargs):
68
+ # small init
69
+ nn.init.normal_(
70
+ self.weight.data, mean=0.0, std=sqrt(2 / 5 / self.weight.shape[-1])
71
+ )
72
+ if self.bias is not None:
73
+ nn.init.zeros_(self.bias.data)
74
+
75
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
76
+ shape = x.shape
77
+ x = x.view(*shape[:-1], self.config.num_heads, -1)
78
+ x = torch.einsum("...hd,hod->...ho", x, self.weight)
79
+ x = x.reshape(*shape[:-1], -1)
80
+ if self.bias is not None:
81
+ x = x + self.bias
82
+ return x
83
+
84
+ def extra_repr(self):
85
+ return (
86
+ f"in_features={self.config.in_features}, "
87
+ f"num_heads={self.config.num_heads}, "
88
+ f"expand_factor_up={self.config.expand_factor_up}, "
89
+ f"bias={self.config.bias}, "
90
+ f"trainable_weight={self.config.trainable_weight}, "
91
+ f"trainable_bias={self.config.trainable_bias}, "
92
+ )
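
Head-wise projection keeps a separate small weight matrix per head instead of one dense layer, which cuts the parameter count by a factor of `num_heads`. A sketch (illustrative sizes, assumed import path):

```python
# Illustrative sizes; assumes the package layout added in this commit.
import torch
from protxlstm.xlstm.components.linear_headwise import (
    LinearHeadwiseExpand,
    LinearHeadwiseExpandConfig,
)

proj = LinearHeadwiseExpand(LinearHeadwiseExpandConfig(in_features=64, num_heads=8, bias=False))
x = torch.randn(2, 10, 64)
print(proj(x).shape)                                 # torch.Size([2, 10, 64])
print(sum(p.numel() for p in proj.parameters()))     # 512, versus 4096 for a dense 64x64 layer
```
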
protxlstm/xlstm/components/ln.py ADDED
@@ -0,0 +1,68 @@
1
+ # Copyright (c) NXAI GmbH and its affiliates 2024
2
+ # Maximilian Beck, Korbinian Pöppel
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+
7
+
8
+ class LayerNorm(nn.Module):
9
+ """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False."""
10
+
11
+ def __init__(
12
+ self,
13
+ ndim: int = -1,
14
+ weight: bool = True,
15
+ bias: bool = False,
16
+ eps: float = 1e-5,
17
+ residual_weight: bool = True,
18
+ ):
19
+ super().__init__()
20
+ self.weight = nn.Parameter(torch.zeros(ndim)) if weight else None
21
+ self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
22
+ self.eps = eps
23
+ self.residual_weight = residual_weight
24
+ self.ndim = ndim
25
+ self.reset_parameters()
26
+
27
+ @property
28
+ def weight_proxy(self) -> torch.Tensor:
29
+ if self.weight is None:
30
+ return None
31
+ if self.residual_weight:
32
+ return 1.0 + self.weight
33
+ else:
34
+ return self.weight
35
+
36
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
37
+ return F.layer_norm(
38
+ input, normalized_shape=(self.ndim,), weight=self.weight_proxy, bias=self.bias, eps=self.eps
39
+ )
40
+
41
+ def reset_parameters(self):
42
+ if self.weight_proxy is not None:
43
+ if self.residual_weight:
44
+ nn.init.zeros_(self.weight)
45
+ else:
46
+ nn.init.ones_(self.weight)
47
+ if self.bias is not None:
48
+ nn.init.zeros_(self.bias)
49
+
50
+
51
+ class MultiHeadLayerNorm(LayerNorm):
52
+
53
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
54
+ assert input.dim() == 4, "Input must be 4D tensor (B, NH, S, DH)"
55
+ B, NH, S, DH = input.shape
56
+
57
+ gn_in_1 = input.transpose(1, 2) # (B, S, NH, DH)
58
+ gn_in_2 = gn_in_1.reshape(B * S, NH * DH) # (B * S, NH * DH)
59
+ out = F.group_norm(
60
+ gn_in_2,
61
+ num_groups=NH,
62
+ weight=self.weight_proxy,
63
+ bias=self.bias,
64
+ eps=self.eps,
65
+ )
66
+ # (B * S), (NH * DH) -> (B, S, NH, DH) -> (B, NH, S, DH)
67
+ out = out.view(B, S, NH, DH).transpose(1, 2)
68
+ return out
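
`MultiHeadLayerNorm` normalises each head of a `(B, NH, S, DH)` tensor independently by casting the operation as a group norm with `NH` groups. A shape-level sketch (illustrative sizes, assumed import path):

```python
# Illustrative sizes; assumes the package layout added in this commit.
import torch
from protxlstm.xlstm.components.ln import MultiHeadLayerNorm

norm = MultiHeadLayerNorm(ndim=4 * 16)           # NH * DH features
h = torch.randn(2, 4, 10, 16)                    # (B, NH, S, DH)
out = norm(h)
print(out.shape)                                 # torch.Size([2, 4, 10, 16])
print(out.mean(dim=-1).abs().max() < 1e-5)       # True: each head is normalised to zero mean
```
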
protxlstm/xlstm/components/rotary_position.py ADDED
@@ -0,0 +1,35 @@
1
+ import torch
2
+ from typing import Tuple
3
+ import math
4
+
5
+
6
+ def compute_freqs_cis(t: torch.Tensor, head_dim: int, theta: float = 10_000.0):
7
+ freqs = theta ** (-torch.arange(0, head_dim, 2).float() / head_dim)
8
+ freqs = t.unsqueeze(-1) * freqs.to(t.device) # type: ignore
9
+ freqs_cos = torch.cos(freqs) # real part
10
+ freqs_sin = torch.sin(freqs) # imaginary part
11
+ return freqs_cos, freqs_sin
12
+
13
+
14
+ def apply_rotary_emb(
15
+ xq: torch.Tensor,
16
+ xk: torch.Tensor,
17
+ freqs_cos: torch.Tensor,
18
+ freqs_sin: torch.Tensor
19
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
20
+
21
+ # reshape xq and xk to match the complex representation
22
+ xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
23
+ xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
24
+
25
+ # apply rotation using real numbers
26
+ xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
27
+ xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
28
+ xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
29
+ xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
30
+
31
+ # flatten last two dimensions
32
+ xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(-2)
33
+ xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(-2)
34
+
35
+ return xq_out.type_as(xq), xk_out.type_as(xk)
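
The rotary embedding treats consecutive channel pairs as complex numbers and rotates them by a position-dependent angle, so relative positions are encoded while vector norms stay unchanged. A minimal sketch (illustrative sizes, assumed import path):

```python
# Illustrative sizes; assumes the package layout added in this commit.
import torch
from protxlstm.xlstm.components.rotary_position import compute_freqs_cis, apply_rotary_emb

B, S, D = 2, 6, 16
positions = torch.arange(S, dtype=torch.float32)
freqs_cos, freqs_sin = compute_freqs_cis(positions, head_dim=D)       # each (S, D/2)

q = torch.randn(B, S, D)
k = torch.randn(B, S, D)
q_rot, k_rot = apply_rotary_emb(q, k, freqs_cos, freqs_sin)

print(q_rot.shape)                                                    # torch.Size([2, 6, 16])
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))  # True: norms preserved
```
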