Elias Buerger committed on
Commit 28f312f · 1 Parent(s): e4995f0
Dockerfile DELETED
@@ -1,34 +0,0 @@
1
- FROM continuumio/anaconda3:main
2
-
3
- WORKDIR /code
4
- COPY ./environment.yml /code/environment.yml
5
-
6
- # Create the environment using the environment.yml file
7
- RUN conda env create -f /code/environment.yml
8
-
9
- # Set up a new user named "user" with user ID 1000
10
- RUN useradd -m -u 1000 user
11
- # Switch to the "user" user
12
- USER user
13
- # Set home to the user's home directory
14
- ENV HOME=/home/user \
15
- PYTHONPATH=$HOME/app \
16
- PYTHONUNBUFFERED=1 \
17
- GRADIO_ALLOW_FLAGGING=never \
18
- GRADIO_NUM_PORTS=1 \
19
- GRADIO_SERVER_NAME=0.0.0.0 \
20
- GRADIO_THEME=huggingface \
21
- SYSTEM=spaces
22
-
23
- # Set the working directory to the user's home directory
24
- WORKDIR $HOME/app
25
-
26
- # Copy the current directory contents into the container at $HOME/app setting the owner to the user
27
- COPY --chown=user . $HOME/app
28
-
29
- # Make run.sh executable (u+x)
30
- RUN chmod u+x $HOME/app/run.sh
31
- RUN chmod -R 777 $HOME/
32
-
33
-
34
- CMD ["./run.sh"]
 
 
 
 
app.py CHANGED
@@ -111,11 +111,13 @@ if __name__ == "__main__":
111
  placeholder=DEFAULT_SEQUENCE,
112
  )
113
  st.session_state.context_sequences = context_sequence_str.split(",") + [st.session_state.target_sequence]
 
114
  elif context_type == 'Use MSA file':
115
  msa_file = st.file_uploader("Choose MSA file")
116
  st.session_state.num_context_sequences = st.number_input("How many of these sequences should be used?", min_value=0, step=1, value=25)
117
  else:
118
  st.session_state.context_sequences = [st.session_state.target_sequence]
 
119
 
120
  if st.session_state.target_sequence != "":
121
  with st.container():
 
111
  placeholder=DEFAULT_SEQUENCE,
112
  )
113
  st.session_state.context_sequences = context_sequence_str.split(",") + [st.session_state.target_sequence]
114
+ msa_file = None
115
  elif context_type == 'Use MSA file':
116
  msa_file = st.file_uploader("Choose MSA file")
117
  st.session_state.num_context_sequences = st.number_input("How many of these sequences should be used?", min_value=0, step=1, value=25)
118
  else:
119
  st.session_state.context_sequences = [st.session_state.target_sequence]
120
+ msa_file = None
121
 
122
  if st.session_state.target_sequence != "":
123
  with st.container():
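
Context for the change above: the two added msa_file = None lines bind the variable in every branch of the context selector, so code further down can inspect it without risking a NameError. A minimal sketch of the pattern, assuming a hypothetical downstream check named use_msa that is not part of the commit:

import streamlit as st

context_type = st.radio("Context", ["Enter sequences", "Use MSA file", "No context"])

if context_type == "Enter sequences":
    context_sequence_str = st.text_input("Context sequences")
    msa_file = None  # bound even though this branch has no upload widget
elif context_type == "Use MSA file":
    msa_file = st.file_uploader("Choose MSA file")
else:
    msa_file = None  # bound in the fallback branch as well

use_msa = msa_file is not None  # hypothetical downstream check; never raises NameError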
environment.yml DELETED
@@ -1,39 +0,0 @@
1
- name: prot_xlstm_app
2
- channels:
3
- - pytorch
4
- - nvidia
5
- - conda-forge
6
- - defaults
7
- dependencies:
8
- - cuda=12.1
9
- - cuda-nvcc=12.1
10
- - gxx_linux-64=11.2.0
11
- - python=3.11
12
- - pip
13
- - pytorch=2.2.0
14
- - pytorch-cuda=12.1
15
- - cmake
16
- - ninja
17
- - pip:
18
- - accelerate>=0.26.0
19
- - biopython #==1.83
20
- - bottleneck #==1.4.2
21
- - dacite #==1.8.1
22
- - ipykernel #==6.29.3
23
- - mamba_ssm==1.2.0
24
- - matplotlib #==3.8.4
25
- - numpy<2.0 #==1.26.4
26
- - omegaconf #==2.3.0
27
- - pandas #==2.2.2
28
- - pyhmmer #==0.10.15
29
- - rich #==13.7.1
30
- - scipy #==1.13.0
31
- - seaborn #==0.13.2
32
- - torchmetrics #==1.2.1
33
- - tqdm #==4.66.4
34
- - transformers==4.44.2
35
- - tueplots #==0.0.17
36
- - wandb #==0.17.0
37
- - streamlit #==1.43.2
38
-
39
-
 
 
 
 
protxlstm/applications/generation_utils/create_sequence_df.py DELETED
@@ -1,85 +0,0 @@
1
- import numpy as np
2
- import pickle
3
- import pandas as pd
4
-
5
- from protxlstm.dataloaders import ProteinMemmapDataset
6
- from protxlstm.utils import decode_sequence, reorder_masked_sequence
7
-
8
-
9
- def create_sequence_df(model_name, family_idx, parameters_list=None, num_sequences = 100, data_dir="./data/"):
10
-
11
- #load dataset
12
- dataset = ProteinMemmapDataset(
13
- msa_memmap_path=f"{data_dir}open_protein_set_memmap.dat",
14
- msa_memmap_meta_path=f"{data_dir}open_protein_set_memmap_indices.csv",
15
- subset_path=f"{data_dir}/cluster_testing_set.txt",
16
- sample=False,
17
- max_msa_len=-1,
18
- reverse=False,
19
- seed=0,
20
- troubleshoot=False,
21
- fim_strategy="multiple_span",
22
- always_mask=False,
23
- max_position_embeddings=2048,
24
- max_seq_position_embeddings=512,
25
- add_position_ids="1d",
26
- mask_fraction=0.2,
27
- max_patches=5,
28
- )
29
-
30
- family_id = list(dataset.dataset_meta["msa_id"])[family_idx]
31
-
32
- if model_name == "natural":
33
-
34
- data = dataset[family_idx]
35
- sequence_df = pd.DataFrame(columns=["family", "family_id", "sequence", "sequence_length"])
36
- tokens = data["input_ids"][None,:]
37
- all_context = decode_sequence(tokens[0].cpu().numpy())
38
- list_sequences_msa = [reorder_masked_sequence(elem+"<cls>") for elem in all_context.split("<cls>")[1:-1]]
39
-
40
- rd_idxs = np.random.choice(len(list_sequences_msa), num_sequences, replace=False)
41
- natural_sequences = [seq for i, seq in enumerate(list_sequences_msa) if i in rd_idxs]
42
-
43
- df_dict = {"family": [family_idx]*len(natural_sequences),
44
- "family_id": [family_id]*len(natural_sequences),
45
- "sequence": natural_sequences,
46
- "sequence_length": [len(seq) for seq in natural_sequences]}
47
-
48
- sequence_df = pd.concat([sequence_df, pd.DataFrame(df_dict)], ignore_index = True)
49
-
50
- else:
51
-
52
- sequence_df = pd.DataFrame(columns=["family", "family_id", "n_seqs_ctx", "temperature", "top_k", "top_p", "original_sequence", "sequence", "sequence_length", "perplexity"])
53
-
54
- if parameters_list is None:
55
- parameters_list = [(10,1.,10,1.), (10,1.,15,1.), (10,1.,10,0.95), (10,0.9,10,0.95), (10,0.8,10,0.9),
56
- (100,1.,10,1.), (100,1.,15,1.), (100,1.,10,0.95), (100,0.9,10,0.95), (100,0.8,10,0.9),
57
- (500,1.,10,1.), (500,1.,15,1.), (500,1.,10,0.95), (500,0.9,10,0.95), (500,0.8,10,0.9),
58
- (1000,1.,10,1.), (1000,1.,15,1.), (1000,1.,10,0.95), (1000,0.9,10,0.95), (1000,0.8,10,0.9),
59
- (-1,1.,10,1.), (-1,1.,15,1.), (-1,1.,10,0.95), (-1,0.9,10,0.95), (-1,0.8,10,0.9)]
60
-
61
- for param in parameters_list:
62
- n_seqs_ctx, temperature, top_k, top_p = param
63
-
64
- with open(f"evaluation/generation/generated_sequences/{model_name}/{family_idx}_{param}_{num_sequences}", "rb") as f:
65
- gen_seqs = pickle.load(f)
66
-
67
- original_sequences = list(gen_seqs[family_idx][param].keys())
68
- reordered_sequences = [reorder_masked_sequence(seq) for seq in original_sequences]
69
- perplexities = [gen_seqs[family_idx][param][seq]["perplexity"] for seq in original_sequences]
70
- df_dict = {"family": [family_idx]*len(original_sequences),
71
- "family_id": [family_id]*len(original_sequences),
72
- "n_seqs_ctx": [n_seqs_ctx]*len(original_sequences),
73
- "temperature": [temperature]*len(original_sequences),
74
- "top_k": [top_k]*len(original_sequences),
75
- "top_p": [top_p]*len(original_sequences),
76
- "original_sequence": original_sequences,
77
- "sequence": reordered_sequences,
78
- "sequence_length": [len(seq) for seq in reordered_sequences],
79
- "perplexity": perplexities
80
- }
81
-
82
- sequence_df = pd.concat([sequence_df, pd.DataFrame(df_dict)], ignore_index = True)
83
-
84
- return sequence_df
85
-
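
As a usage sketch of the deleted helper above (the family index, sequence count, and data paths are illustrative, not taken from the commit):

from protxlstm.applications.generation_utils.create_sequence_df import create_sequence_df

# Collect 100 natural sequences for one family of the test split into a DataFrame.
sequence_df = create_sequence_df("natural", family_idx=0, num_sequences=100, data_dir="./data/")
print(sequence_df[["family_id", "sequence_length"]].head())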
 
 
 
 
protxlstm/applications/generation_utils/score_hamming.py DELETED
@@ -1,80 +0,0 @@
1
- import numpy as np
2
- from tqdm import tqdm
3
- import pandas as pd
4
- from Bio import Align
5
-
6
- from protxlstm.dataloaders import ProteinMemmapDataset
7
- from protxlstm.utils import decode_sequence, reorder_masked_sequence
8
-
9
-
10
- aligner = Align.PairwiseAligner()
11
- aligner.mode = 'global'
12
- aligner.match_score = 1
13
- aligner.mismatch_score = -1
14
- aligner.open_gap_score = -1
15
- aligner.extend_gap_score = -1
16
-
17
- def align_sequences(ref_seq, query_seq, print_alignments=False):
18
- def hamming_str(s1,s2):
19
- assert len(s1) == len(s2)
20
- return sum(np.array(list(s1)) != np.array(list(s2)))/len(s1)
21
- alignments = aligner.align(ref_seq, query_seq)
22
- if print_alignments:
23
- print("Score = %.1f:" % alignments[0].score)
24
- print(alignments[0])
25
- return hamming_str(alignments[0][0], alignments[0][1]), alignments[0][0], alignments[0][1]
26
-
27
-
28
- def score_hamming(sequence_df, family_idx, data_dir = f"./data/"):
29
-
30
- assert len(set(list(sequence_df["family"]))) == 1 and sequence_df["family"].iloc[0] == family_idx
31
-
32
- #load dataset
33
- dataset = ProteinMemmapDataset(
34
- msa_memmap_path=f"{data_dir}open_protein_set_memmap.dat",
35
- msa_memmap_meta_path=f"{data_dir}open_protein_set_memmap_indices.csv",
36
- subset_path=f"{data_dir}/cluster_testing_set.txt",
37
- sample=False,
38
- max_msa_len=-1,
39
- reverse=False,
40
- seed=0,
41
- troubleshoot=False,
42
- fim_strategy="multiple_span",
43
- always_mask=False,
44
- max_position_embeddings=2048,
45
- max_seq_position_embeddings=512,
46
- add_position_ids="1d",
47
- mask_fraction=0.2,
48
- max_patches=5,
49
- )
50
-
51
- # Select a sample of the dataset to be the input
52
- data = dataset[family_idx]
53
- tokens = data["input_ids"][None,:]
54
- all_context = decode_sequence(tokens[0].cpu().numpy())
55
- list_sequences_msa = [reorder_masked_sequence(elem+"<cls>") for elem in all_context.split("<cls>")[1:-1]]
56
-
57
- # sequence_df["hamming"] = pd.Series(dtype=object)
58
- sequence_df["min_hamming"] = pd.Series()
59
- sequence_df["median_hamming"] = pd.Series()
60
- sequence_df["mean_hamming"] = pd.Series()
61
- sequence_df["std_hamming"] = pd.Series()
62
-
63
- for seq in tqdm(list(sequence_df["sequence"])):
64
-
65
- all_hamming = []
66
- for ctx_seq in list_sequences_msa:
67
- if ctx_seq == seq:
68
- continue
69
- else:
70
- hamming, _, _ = align_sequences(ctx_seq, seq , print_alignments=False)
71
- all_hamming.append(hamming)
72
-
73
- # sequence_df.loc[sequence_df["sequence"] == seq, "hamming"] = [all_hamming]
74
- sequence_df.loc[sequence_df["sequence"] == seq, "min_hamming"] = np.min(all_hamming)
75
- sequence_df.loc[sequence_df["sequence"] == seq, "median_hamming"] = np.median(all_hamming)
76
- sequence_df.loc[sequence_df["sequence"] == seq, "mean_hamming"] = np.mean(all_hamming)
77
- sequence_df.loc[sequence_df["sequence"] == seq, "std_hamming"] = np.std(all_hamming)
78
-
79
- return sequence_df
80
-
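
The distance reported above is the mismatch fraction over a global pairwise alignment; a toy example of that fraction on two strings that are already aligned (independent of the Biopython aligner used in the deleted code):

import numpy as np

def hamming_fraction(s1: str, s2: str) -> float:
    # Fraction of positions at which two equal-length, aligned strings differ.
    assert len(s1) == len(s2)
    return sum(np.array(list(s1)) != np.array(list(s2))) / len(s1)

print(hamming_fraction("ACDE", "ACDF"))  # 0.25: one mismatch out of four positions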
 
 
 
 
protxlstm/applications/generation_utils/score_hmmer.py DELETED
@@ -1,102 +0,0 @@
1
- import string
2
- from Bio import SeqIO
3
- import pyhmmer
4
- from tqdm import tqdm
5
-
6
- alphabet = pyhmmer.easel.Alphabet.amino()
7
-
8
- # This is an efficient way to delete lowercase characters and insertion characters from a string
9
- deletekeys = dict.fromkeys(string.ascii_lowercase)
10
- deletekeys["."] = None
11
- deletekeys["*"] = None
12
- translation = str.maketrans(deletekeys)
13
-
14
- def remove_insertions(sequence: str) -> str:
15
- """ Removes any insertions into the sequence. Needed to load aligned sequences in an MSA. """
16
- return sequence.translate(translation)
17
-
18
- def read_msa(filename: str):
19
- """ Reads the sequences from an MSA file, automatically removes insertions."""
20
- return [(record.description, remove_insertions(str(record.seq))) for record in SeqIO.parse(filename, "fasta")]
21
-
22
- def read_msa_unaligned(filename: str):
23
- """ Reads the sequences from an MSA file, removes only . - and * characters."""
24
- return [(record.description, str(record.seq).replace(".","").replace("-","").replace("*","").upper()) for record in SeqIO.parse(filename, "fasta")]
25
-
26
- def check_msa(msa):
27
- """ Checks if there are any repeated sequences in the MSA"""
28
- seqs = set()
29
- for el in msa:
30
- seqs.add(el[1])
31
- assert len(seqs) == len(msa), "There are repeated sequences in the MSA"
32
-
33
- def make_hmm_from_a3m_msa(msa_filepath, hmm_filename=None):
34
- # Load MSA from a3m
35
- msa_tup = read_msa(msa_filepath)
36
- # check_msa(msa_tup)
37
- # Create digitized MSA block
38
- all_seqs = [pyhmmer.easel.TextSequence(name=str(i).encode("utf-8"), sequence=seq) for i, (idz, seq) in enumerate(msa_tup)]
39
- msa = pyhmmer.easel.TextMSA(name=b"msa", sequences=all_seqs)
40
- msa = msa.digitize(alphabet)
41
- # Fit HMM
42
- builder = pyhmmer.plan7.Builder(alphabet)
43
- background = pyhmmer.plan7.Background(alphabet)
44
- hmm, _, _ = builder.build_msa(msa, background)
45
- if hmm_filename is not None:
46
- with open(f"{hmm_filename}.hmm", "wb") as output_file:
47
- hmm.write(output_file)
48
- return hmm
49
-
50
- def align_and_score_sequences_in_a3m_with_hmm(hmm, sequences_path=None, sequences_list=None):
51
- if sequences_list is not None:
52
- msa = sequences_list
53
- all_seqs = [pyhmmer.easel.TextSequence(name=str(i).encode("utf-8"), sequence=seq) for i, seq in enumerate(sequences_list)]
54
- elif sequences_path is not None:
55
- # Load sequences from a3m
56
- msa = read_msa_unaligned(sequences_path)
57
- all_seqs = [pyhmmer.easel.TextSequence(name=str(i).encode("utf-8"), sequence=seq) for i, (idz, seq) in enumerate(msa)]
58
- else:
59
- raise NotImplementedError("Missing sequences to align/score")
60
- # Create digitized Sequence block
61
- seq_block = pyhmmer.easel.TextSequenceBlock(all_seqs)
62
- seq_block = seq_block.digitize(alphabet)
63
- # Get all hits from the hmm
64
- background = pyhmmer.plan7.Background(alphabet)
65
- pipeline = pyhmmer.plan7.Pipeline(alphabet, background=background, bias_filter=False, F1=1.0, F2=1.0, F3=1.0)
66
- hits = pipeline.search_hmm(hmm, seq_block)
67
- if len(hits) != len(msa):
68
- print(f"Number of hits: {len(hits)} is different from the number of sequences in the MSA: {len(msa)}")
69
- # Extract hits
70
- all_hits = {}
71
- for hit in hits:
72
- idz, score, evalue = hit.name, hit.score, hit.evalue
73
- i = int(idz.decode("utf-8"))
74
- seq = msa[i][1] if sequences_path is not None else sequences_list[i]
75
- all_hits[seq] = {"score": score, "evalue": evalue}
76
- return all_hits
77
-
78
-
79
- def score_hmmer(sequence_df, family_idx, data_dir = f"./data/"):
80
-
81
- assert len(set(list(sequence_df["family"]))) == 1 and sequence_df["family"].iloc[0] == family_idx
82
-
83
- family_id = sequence_df["family_id"].iloc[0]
84
- msa_filepath = f"{data_dir}/a3m_files/{family_id}/a3m/uniclust30.a3m"
85
- try:
86
- hmm = make_hmm_from_a3m_msa(msa_filepath)
87
- except:
88
- raise Exception(f"Missing MSA of family {family_id}")
89
-
90
- # align sequences
91
- sequences = list(sequence_df["sequence"])
92
- scores = align_and_score_sequences_in_a3m_with_hmm(hmm, sequences_list=sequences)
93
-
94
- # save the scores associated to each sequence in the main df in the columns "score" and "evalue"
95
- for seq in tqdm(sequences):
96
- sequence_df.loc[sequence_df["sequence"] == seq, "score_gen"] = scores[seq]["score"] if seq in scores.keys() else 0
97
- sequence_df.loc[sequence_df["sequence"] == seq, "evalue_gen"] = scores[seq]["evalue"] if seq in scores.keys() else 1
98
-
99
- return sequence_df
100
-
101
-
102
-
 
 
 
 
protxlstm/applications/generation_utils/score_structure.py DELETED
@@ -1,55 +0,0 @@
1
- from Bio.PDB import PDBParser
2
- import torch
3
- from tqdm import tqdm
4
- from transformers import EsmForProteinFolding
5
-
6
- from protxlstm.utils import MASK_TO_ID
7
-
8
-
9
- pdb_parser = PDBParser()
10
-
11
-
12
- def compute_structure(seq, model):
13
- def keep_sequence(seq, l):
14
- if len(seq) > l:
15
- return False
16
- for mm in list(MASK_TO_ID.keys())+["<eos>", "<pad>", "<unk>", "<mask>", "<cls>", "<null_1>", "." , "-"]:
17
- if mm in seq:
18
- return False
19
- return True
20
- keep = keep_sequence(seq, l=750)
21
- if keep:
22
- with torch.no_grad():
23
- output = model.infer([seq])
24
- # pdb = model.output_to_pdb(output)
25
- ptm = output["ptm"].item()
26
- pae = output["predicted_aligned_error"].cpu().numpy()
27
- mean_plddt = ((output["plddt"] * output["atom37_atom_exists"]).sum(dim=(1, 2)) / output["atom37_atom_exists"].sum(dim=(1, 2))).item()
28
- pos_plddt = ((output["plddt"] * output["atom37_atom_exists"]).sum(dim=(2,)) / output["atom37_atom_exists"].sum(dim=(2,))).cpu().numpy()
29
- else:
30
- print(f"Sequence is invalid.")
31
- ptm, pae, mean_plddt, pos_plddt = 0, 0 ,0 , 0
32
- return ptm, pae, mean_plddt, pos_plddt
33
-
34
-
35
- def score_structure(sequence_df, family_idx):
36
-
37
- assert len(set(list(sequence_df["family"]))) == 1 and sequence_df["family"].iloc[0] == family_idx
38
-
39
- device="cuda:0"
40
-
41
- # Import the folding model
42
- model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True)
43
-
44
- model = model.cuda(device)
45
- model.esm = model.esm.half()
46
- torch.backends.cuda.matmul.allow_tf32 = True
47
-
48
- sequences = list(sequence_df["sequence"])
49
- for seq in tqdm(sequences):
50
-
51
- ptm, pae, mean_plddt, pos_plddt = compute_structure(seq, model)
52
- sequence_df.loc[sequence_df["sequence"] == seq, "ptm"] = ptm
53
- sequence_df.loc[sequence_df["sequence"] == seq, "mean_plddt"] = mean_plddt
54
-
55
- return sequence_df
 
 
 
 
protxlstm/applications/sample_sequences.py DELETED
@@ -1,200 +0,0 @@
1
- import torch
2
- from tqdm import tqdm
3
- import pickle
4
- import os
5
- import argparse
6
- import json
7
-
8
- from protxlstm.dataloaders import ProteinMemmapDataset
9
- from protxlstm.generation import generate_sequence
10
- from protxlstm.utils import (
11
- AA_TO_ID,
12
- load_model,
13
- )
14
- from protxlstm.models.xlstm import xLSTMLMHeadModel
15
- from protxlstm.models.mamba import MambaLMHeadModelwithPosids
16
-
17
-
18
- def sample_sequences(dataset,
19
- model,
20
- family_idx,
21
- params,
22
- n_samples_per_family,
23
- max_length=1000,
24
- chunk_chunk_size=2**15,
25
- save_path=None,
26
- device="cuda:0"):
27
- """
28
- Function to sample sequences from the model. Given a dataset, a list of families (their indexes in the dataset)
29
- and a set of generating parameters, it generates `n_samples_per_family` sequences for each family and each parameter set.
30
- The function returns a dictionary with the following structure:
31
- gen_seqs = {family_idx: {parameters: {sequence: perplexity}}}
32
- The parameters are in a list of tuples with the following structure:
33
- parameters_list = [(nr_seqs_ctx, temperature, top_k, top_p)]
34
- """
35
- gen_seqs = {}
36
- gen_seqs[family_idx] = {}
37
- gen_seqs[family_idx][params] = {}
38
- print(f"Sampling sequences for family {family_idx} and parameters {params}.")
39
-
40
- n_seqs_ctx , temperature, top_k, top_p = params
41
- for _ in tqdm(range(n_samples_per_family)):
42
- # Sample the dataset to get the input
43
- data = dataset[family_idx]
44
- tokens = data["input_ids"][None,:].to(device)
45
- pos_ids = data["position_ids"][None,:].to(device)
46
-
47
- start_seqs = torch.argwhere(tokens[0]==0)[:,0].cpu().numpy()
48
-
49
- n_seqs_ctx = len(start_seqs) if len(start_seqs) < n_seqs_ctx else n_seqs_ctx
50
- L = start_seqs[n_seqs_ctx]+1
51
- context_tokens = tokens[:,:L]
52
- context_pos_ids = pos_ids[:,:L]
53
- is_fim={}
54
-
55
- # Generate the new sequence
56
- output = generate_sequence(model,
57
- context_tokens,
58
- position_ids=context_pos_ids,
59
- is_fim=is_fim,
60
- max_length=(L+max_length),
61
- temperature=temperature,
62
- top_k=top_k,
63
- top_p=top_p,
64
- return_dict_in_generate=True,
65
- output_scores=True,
66
- eos_token_id=torch.tensor([AA_TO_ID["<cls>"]]).to(device),
67
- chunk_chunk_size=chunk_chunk_size,
68
- device=device)
69
-
70
- # Get the perplexity of the generated sequence
71
- output_seq = output["generated"]
72
- loss = torch.nn.functional.cross_entropy(torch.from_numpy(output["scores"]).permute(0, 2, 1),
73
- torch.from_numpy(output["generated_tokens"][0][None,:]))
74
-
75
- # save only sequences with length < max_length
76
- if len(output_seq[0]) < max_length:
77
-
78
- gen_seqs[family_idx][params][output_seq[0]] = {"perplexity": torch.exp(loss).item()}
79
-
80
- if save_path is not None:
81
- if not os.path.exists("evaluation/generation/generated_sequences"):
82
- os.mkdir("evaluation/generation/generated_sequences")
83
- if not os.path.exists(save_path):
84
- os.mkdir(save_path)
85
- with open(f'{save_path}/{family_idx}_{params}_{n_samples_per_family}', "wb") as f:
86
- pickle.dump(gen_seqs, f)
87
- print(f"Sequences saved for family {family_idx} and parameters {params}")
88
-
89
- return gen_seqs
90
-
91
- def generate_sequences(model_name,
92
- checkpoint,
93
- family_idxs=[],
94
- parameters_list=[],
95
- n_samples_per_family = 100,
96
- chunk_size=1024,
97
- chunk_chunk_size=2**15,
98
- data_dir="data/",
99
- device="cuda:0"
100
- ):
101
-
102
- # Load the test dataset
103
- fim_strategy = "multiple_span"
104
- mask_fraction = 0.2
105
-
106
- dataset = ProteinMemmapDataset(
107
- msa_memmap_path=f"{data_dir}open_protein_set_memmap.dat",
108
- msa_memmap_meta_path=f"{data_dir}open_protein_set_memmap_indices.csv",
109
- subset_path=f"{data_dir}cluster_testing_set.txt",
110
- sample=False,
111
- max_msa_len=-1,
112
- reverse=False,
113
- seed=0,
114
- troubleshoot=False,
115
- fim_strategy=fim_strategy,
116
- always_mask=False,
117
- max_position_embeddings=2048,
118
- max_seq_position_embeddings=512,
119
- add_position_ids="1d",
120
- mask_fraction=mask_fraction
121
- )
122
-
123
- if model_name == "xlstm":
124
- model_class = xLSTMLMHeadModel
125
- elif model_name == "mamba":
126
- model_class = MambaLMHeadModelwithPosids
127
-
128
- save_path = f"evaluation/generation/generated_sequences/{checkpoint.split('/')[-1]}"
129
-
130
- if model_name == "xlstm":
131
- config_update_kwargs = {
132
- "mlstm_backend": "chunkwise_variable",
133
- "mlstm_chunksize": chunk_size,
134
- "mlstm_return_last_state": True
135
- }
136
- else:
137
- config_update_kwargs = {}
138
-
139
-
140
- #load the model
141
- model = load_model(checkpoint,
142
- model_class=model_class,
143
- device=device,
144
- dtype=torch.bfloat16,
145
- **config_update_kwargs,
146
- )
147
- model = model.eval()
148
- print("Model loaded.")
149
-
150
- for family_idx in family_idxs:
151
- for params in parameters_list:
152
- params = tuple(params)
153
- if not os.path.exists(f'{save_path}/{family_idx}_{params}_{n_samples_per_family}'):
154
- gen_seqs = sample_sequences(
155
- dataset=dataset,
156
- model=model,
157
- family_idx=family_idx,
158
- params=params,
159
- n_samples_per_family=n_samples_per_family,
160
- chunk_chunk_size=chunk_chunk_size,
161
- save_path=save_path,
162
- device=device)
163
-
164
- print(f"Sampled {len(gen_seqs[family_idx][params])} valid sequences.")
165
- else:
166
- print(f"Sequences for family {family_idx} and parameters {params} already exist.")
167
-
168
-
169
- if __name__ == "__main__":
170
-
171
- parser = argparse.ArgumentParser(
172
- description="Generate sequences."
173
- )
174
- parser.add_argument("--model_name", type=str, help="Either 'xlstm' or 'mamba'.")
175
- parser.add_argument("--checkpoint", type=str, help="Path to model checkpoint.")
176
- parser.add_argument("--family_idxs", type=str, help="List of family indices.")
177
- parser.add_argument("--parameters_list", type=str, help="List of sampling parameters.")
178
- parser.add_argument("--n_samples_per_family", type=int, default=100, help="Number of sequences to sample per family and parameter set.")
179
- parser.add_argument("--chunk_size", type=int, default=1024, help="Chunk size for xLSTM context encoding.")
180
- parser.add_argument("--chunk_chunk_size", type=int, default=2**15, help="Length of context sequence part processed at once.")
181
- parser.add_argument("--data_dir", type=str, default="data/", help="Path to dataset.")
182
- parser.add_argument("--device", type=str, default="cuda:0", help="Device.")
183
-
184
- args = parser.parse_args()
185
-
186
- family_idxs = json.loads(args.family_idxs)
187
- parameters_list = json.loads(args.parameters_list)
188
-
189
- # Run sequence generation
190
- generate_sequences(
191
- model_name=args.model_name,
192
- checkpoint=args.checkpoint,
193
- family_idxs=family_idxs,
194
- parameters_list=parameters_list,
195
- n_samples_per_family=args.n_samples_per_family,
196
- chunk_size=args.chunk_size,
197
- chunk_chunk_size=args.chunk_chunk_size,
198
- data_dir=args.data_dir,
199
- device=args.device,
200
- )
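
As an invocation sketch of the entry point above (checkpoint path, family indices, and sampling parameters are illustrative only; the checkpoint name is borrowed from the directory listing in protxlstm/index.html):

from protxlstm.applications.sample_sequences import generate_sequences

# Sample 100 sequences per family for two families with one
# (n_seqs_ctx, temperature, top_k, top_p) parameter set.
generate_sequences(
    model_name="xlstm",
    checkpoint="checkpoints/protxlstm_26M_30B",
    family_idxs=[0, 1],
    parameters_list=[(100, 1.0, 10, 1.0)],
    n_samples_per_family=100,
    data_dir="data/",
    device="cuda:0",
)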
 
 
 
 
protxlstm/applications/score_sequences.py DELETED
@@ -1,58 +0,0 @@
1
- import argparse
2
- import os
3
- import pickle
4
-
5
- from generation_utils.create_sequence_df import create_sequence_df
6
- from generation_utils.score_hamming import score_hamming
7
- from generation_utils.score_hmmer import score_hmmer
8
- from generation_utils.score_structure import score_structure
9
-
10
-
11
- def score_sequences(model_name,
12
- family_idx,
13
- num_sequences = 100,
14
- data_dir = "data/"):
15
-
16
- if os.path.isfile(f"evaluation/generation/evaluations/{model_name}/sequence_df_{family_idx}"):
17
- with open(f"evaluation/generation/evaluations/{model_name}/sequence_df_{family_idx}", "rb") as f:
18
- sequence_df = pickle.load(f)
19
- else:
20
- sequence_df = create_sequence_df(model_name, family_idx, data_dir = data_dir, num_sequences = num_sequences)
21
- if not os.path.exists("evaluation/generation/evaluations/"):
22
- os.mkdir("evaluation/generation/evaluations/")
23
- if not os.path.exists(f"evaluation/generation/evaluations/{model_name}/"):
24
- os.mkdir(f"evaluation/generation/evaluations/{model_name}/")
25
- with open(f"evaluation/generation/evaluations/{model_name}/sequence_df_{family_idx}", "wb") as f:
26
- pickle.dump(sequence_df, f)
27
-
28
- if not "min_hamming" in sequence_df.columns:
29
- sequence_df = score_hamming(sequence_df, family_idx, data_dir)
30
- with open(f"evaluation/generation/evaluations/{model_name}/sequence_df_{family_idx}", "wb") as f:
31
- pickle.dump(sequence_df, f)
32
-
33
- if not "score_gen" in sequence_df.columns:
34
- sequence_df = score_hmmer(sequence_df, family_idx, data_dir)
35
- with open(f"evaluation/generation/evaluations/{model_name}/sequence_df_{family_idx}", "wb") as f:
36
- pickle.dump(sequence_df, f)
37
-
38
- if not "ptm" in sequence_df.columns:
39
- sequence_df = score_structure(sequence_df, family_idx)
40
- with open(f"evaluation/generation/evaluations/{model_name}/sequence_df_{family_idx}", "wb") as f:
41
- pickle.dump(sequence_df, f)
42
-
43
- return sequence_df
44
-
45
-
46
- if __name__ == "__main__":
47
-
48
- parser = argparse.ArgumentParser(
49
- description="Generate sequences."
50
- )
51
- parser.add_argument("--model_name", type=str, help="Either 'xlstm' or 'mamba'.")
52
- parser.add_argument("--family_idx", type=int, help="Family index.")
53
- parser.add_argument("--num_sequences", type=int, default=100, help="Number of sequences.")
54
- parser.add_argument("--data_dir", type=str, default="./data/", help="Path to dataset.")
55
-
56
- args = parser.parse_args()
57
-
58
- sequence_df = score_sequences(args.model_name, args.family_idx, args.num_sequences, args.data_dir)
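
Equivalently, the scoring pipeline above can be driven from Python instead of the CLI (model name and family index are illustrative); intermediate results are cached under evaluation/generation/evaluations/<model_name>/:

from protxlstm.applications.score_sequences import score_sequences

# Runs Hamming, HMMER and structure scoring for one family and returns the combined DataFrame.
sequence_df = score_sequences("protxlstm_26M_30B", family_idx=20, num_sequences=100, data_dir="./data/")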
 
 
 
 
protxlstm/data.py DELETED
@@ -1,60 +0,0 @@
1
- import csv
2
- import os
3
-
4
- import numpy as np
5
- from tqdm import tqdm
6
-
7
- from protxlstm.utils import load_sequences_from_msa_file, tokenizer
8
-
9
- def process_msa(msa_item):
10
- msa_name, msa_path = msa_item
11
- # Load an a3m file with all the context sequences
12
- msa = load_sequences_from_msa_file(msa_path)
13
- # Tokenize the sequences and concatenate them into a single array
14
- tokens = tokenizer(msa, concatenate=True)
15
- tokens = tokens.numpy()[0]
16
- return msa_name, tokens
17
-
18
- def main(data_dir, output_dir):
19
- msa_paths = {k: os.path.join(data_dir, k, 'a3m/uniclust30.a3m') for k in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, k))}
20
- msa_items = list(msa_paths.items())
21
-
22
- dataset_dictionary = {}
23
- total_length = 0
24
-
25
- # First pass: calculate total length of all concatenated arrays
26
- for item in tqdm(msa_items):
27
- try:
28
- k, v = process_msa(item)
29
- dataset_dictionary[k] = v
30
- total_length += len(v)
31
- except:
32
- print(f"Error processing {item}")
33
-
34
- # Initialize the memmap array with the calculated total length
35
- memmap_path = os.path.join(output_dir, 'open_protein_set_memmap.dat')
36
- concatenated_array = np.memmap(memmap_path, dtype='int8', mode='w+', shape=(total_length,))
37
-
38
- with open(f'{output_dir}/open_protein_set_memmap_indices.csv', 'w', newline='') as csvfile:
39
- csvwriter = csv.writer(csvfile)
40
-
41
- csvwriter.writerow(['msa_id', 'Start', 'End'])
42
-
43
- start_index = 0
44
- for key, array in dataset_dictionary.items():
45
- end_index = start_index + len(array) - 1
46
- concatenated_array[start_index:end_index + 1] = array # Write to memmap
47
- csvwriter.writerow([key, start_index, end_index])
48
- start_index = end_index + 1
49
-
50
- # Ensure the data is written to disk
51
- concatenated_array.flush()
52
-
53
-
54
- if __name__ == "__main__":
55
- data_dir = 'data/a3m_files'
56
- output_dir = 'data/'
57
- main(data_dir, output_dir)
58
-
59
-
60
-
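
The memmap written by main above can be read back together with its index CSV; a minimal sketch using the script's default output paths:

import numpy as np
import pandas as pd

tokens = np.memmap("data/open_protein_set_memmap.dat", dtype="int8", mode="r")
meta = pd.read_csv("data/open_protein_set_memmap_indices.csv")

# The End column is inclusive, matching how the writer slices start_index:end_index + 1.
row = meta.iloc[0]
cluster_tokens = tokens[row.Start : row.End + 1]
print(row.msa_id, len(cluster_tokens))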
 
 
 
 
protxlstm/dataloaders.py DELETED
@@ -1,249 +0,0 @@
1
- # Original code from ProtMamba under Apache License 2.0.
2
- #
3
- # Modifications made by Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
4
- # - Uniclust30_Dataset renamed to ProteinMemmapDataset
5
- # - Dataset input file format changed for more efficient dataloading
6
- # - Option to use only a subset
7
- # - DataCollatorForUniclust30Dataset renamed to ProteinDataCollator
8
- # - Add sequence padding
9
-
10
- import numpy as np
11
- import pandas as pd
12
- import torch
13
- from torch.utils.data import DataLoader, Dataset
14
- from typing import Dict, Optional, Sequence
15
-
16
- from protxlstm.fim import MultipleSpanFIM, NoFIM, SingleSpanFIM
17
- from protxlstm.utils import AA_TO_ID
18
-
19
-
20
- # Make dataset
21
- class ProteinMemmapDataset(Dataset):
22
- """
23
- ProteinMemmapDataset is a PyTorch Dataset class for handling memory-mapped datasets of protein multiple sequence alignments (MSAs).
24
-
25
- This class imports MSA data stored in memmap format and associated metadata CSVs. It supports flexible
26
- data sampling strategies and inpainting methods for sequence manipulation and training purposes.
27
-
28
- Args:
29
- msa_memmap_path (str): Path to the memory-mapped file containing the MSA clusters.
30
- msa_memmap_meta_path (str): Path to the CSV file with metadata linking MSA Cluster IDs and indices in the memmap array.
31
- subset_path (str, optional): Path to a CSV file specifying a subset of cluster IDs to use.
32
- sample (bool, optional): If True, randomly samples sequences from each cluster; otherwise, loads all sequences and shuffles them.
33
- max_msa_len (int, optional): Maximum length of the MSA sequences to include. Defaults to -1 (no limit).
34
- reverse (bool, optional): If True, reverses sequences with a probability of 0.5 and moves the last token to the front.
35
- seed (int, optional): Random seed for reproducibility. Defaults to 42.
36
- troubleshoot (bool, optional): If True, prints debugging information. Defaults to False.
37
- fim_strategy (str, optional): Strategy for inpainting ("no-scramble", "one_span", or "multiple_span").
38
- max_patches (int, optional): Number of patches for inpainting. Used when fim_strategy is "multiple_span".
39
- mask_fraction (float, optional): Fraction of the patches to mask. Used when fim_strategy is "multiple_span".
40
- always_mask (bool, optional): If True, ensures masking is applied in the inpainting process.
41
- max_position_embeddings (int, optional): Maximum position embeddings. Defaults to 2048.
42
- max_seq_position_embeddings (int, optional): Maximum sequence position embeddings for 2D positional IDs. Defaults to 512.
43
- add_position_ids (str, optional): Type of position IDs to add ("none", "1d", or "2d"). Defaults to "1d".
44
- """
45
-
46
- _FIM = {"no-scramble": NoFIM, "one_span": SingleSpanFIM, "multiple_span": MultipleSpanFIM}
47
- _POSIDS = {"none", "1d", "2d"}
48
-
49
- def __init__(self,
50
- msa_memmap_path=None,
51
- msa_memmap_meta_path=None,
52
- subset_path=None,
53
- sample=False,
54
- max_msa_len=-1,
55
- reverse=False,
56
- seed=42,
57
- troubleshoot=False,
58
- fim_strategy="no-scramble",
59
- max_patches=5,
60
- mask_fraction=0.2,
61
- always_mask=False,
62
- max_position_embeddings=2048,
63
- max_seq_position_embeddings=512,
64
- add_position_ids="1d", ):
65
-
66
- np.random.seed(seed)
67
-
68
- if msa_memmap_path:
69
- self.dataset = np.memmap(msa_memmap_path, dtype=np.int8, mode='r')
70
- self.dataset_meta = pd.read_csv(msa_memmap_meta_path)
71
- if subset_path:
72
- subset_ids = pd.read_csv(subset_path, header=None, names=['ID'])['ID'].tolist()
73
- self.dataset_meta = self.dataset_meta[self.dataset_meta['msa_id'].isin(subset_ids)]
74
- else:
75
- self.dataset = None
76
-
77
- self.sample = sample
78
- self.max_msa_len = max_msa_len
79
- self.reverse = reverse
80
- self.fim_strategy = fim_strategy
81
- if fim_strategy in ProteinMemmapDataset._FIM:
82
- self.fim = ProteinMemmapDataset._FIM[fim_strategy](max_patches=max_patches,
83
- mask_fraction=mask_fraction,
84
- always_mask=always_mask,
85
- add_position_ids=add_position_ids != "none",
86
- troubleshoot=troubleshoot)
87
- else:
88
- raise ValueError(f'Fill-in-the-middle strategy "{fim_strategy}" not recognized.')
89
-
90
- self.max_position_embeddings = max_position_embeddings
91
- self.max_seq_position_embeddings = max_seq_position_embeddings
92
- self.add_position_ids = add_position_ids
93
-
94
- self.troubleshoot = troubleshoot
95
-
96
- def __len__(self):
97
- # meta dataframe has one row for each MSA cluster
98
- return len(self.dataset_meta)
99
-
100
- def __getitem__(self, idx):
101
- # get all the sequences in the cluster
102
- sequences = self.get_sequences(idx)
103
- # get total number of sequences in the cluster and choose how many to sample
104
- orig_num_sequences = len(self.get_index_start_of_sequences(sequences))
105
- num_sequences = np.random.randint(1, orig_num_sequences + 1) if self.sample else orig_num_sequences
106
- # sample the sequences
107
- sequences, position_ids = self.sample_sequences(sequences, num_sequences)
108
- # with probability 0.5, reverse the sequences and move the last token to the front
109
- sequences, position_ids = (self.reverse_sequences(sequences, position_ids) if (
110
- self.reverse and np.random.rand() > 0.5) else (sequences, position_ids))
111
- # limit the length of the MSA
112
- sequences = sequences[:self.max_msa_len] if self.max_msa_len > 0 else sequences
113
- if self.add_position_ids != "none":
114
- position_ids = position_ids[:self.max_msa_len] if self.max_msa_len > 0 else position_ids
115
- # convert to tensor
116
- sequences = torch.asarray(sequences, dtype=torch.int64)
117
- position_ids = torch.asarray(position_ids, dtype=torch.int64).clamp(0,
118
- self.max_position_embeddings - 1) if self.add_position_ids!="none" else None
119
-
120
- if self.troubleshoot:
121
- print(
122
- f"Cluster {idx} has {orig_num_sequences} sequences, of which {num_sequences} sampled now. Total MSA length: {len(sequences)}")
123
- if self.add_position_ids == "1d":
124
- return dict(input_ids=sequences, position_ids=position_ids, labels=sequences)
125
- if self.add_position_ids == "2d":
126
- seq_position_ids = (sequences == AA_TO_ID["<cls>"]).int().cumsum(-1).clamp(0,
127
- self.max_seq_position_embeddings - 1).contiguous()
128
- return dict(input_ids=sequences, position_ids=position_ids, seq_position_ids=seq_position_ids,
129
- labels=sequences)
130
- return dict(input_ids=sequences, labels=sequences)
131
-
132
- def get_msa_id(self, idx):
133
- """Get the MSA ID in the cluster with index `idx`."""
134
- cluster_meta = self.dataset_meta.iloc[idx]
135
- return cluster_meta.msa_id
136
-
137
- def get_idx_from_msa_id(self, msa_id):
138
- """Get `idx` with the MSA ID"""
139
- return self.dataset_meta[self.dataset_meta.msa_id == msa_id].index[0]
140
-
141
- def get_sequences(self, idx):
142
- """Get the sequences in the cluster with index `idx`."""
143
- cluster_meta = self.dataset_meta.iloc[idx]
144
- sequences = self.dataset[cluster_meta.Start : cluster_meta.End]
145
- return sequences
146
-
147
- def get_index_start_of_sequences(self, sequences):
148
- """Get the positions of the start of each sequence in the cluster."""
149
- return np.where(sequences == 0)[0]
150
-
151
- def reverse_sequences(self, sequence, position_ids=None):
152
- """Reverse the sequences and move the last token to the front."""
153
- sequence = sequence[::-1]
154
- if position_ids is not None:
155
- position_ids = position_ids[::-1]
156
- return np.concatenate([sequence[-1:], sequence[:-1]]), np.concatenate(
157
- [position_ids[-1:], position_ids[:-1]]) if position_ids is not None else None
158
-
159
- def sample_sequences(self, sequences, num_sequences, shuffle=True):
160
- """Sample `num_sequences` from the sequences in the cluster."""
161
- L = len(sequences)
162
- # get the indexes of the start of each sequence
163
- inds = self.get_index_start_of_sequences(sequences)
164
- # check that there are sequences in the cluster and that there are enough of them
165
- assert len(inds) > 0, "No sequences found in cluster."
166
- assert len(inds) >= num_sequences, "Not enough sequences in cluster."
167
- # sample n_sequences randomly from the sequences
168
- if shuffle:
169
- which_seqs = np.random.choice(np.arange(len(inds)), num_sequences, replace=False)
170
- else:
171
- which_seqs = np.arange(len(inds))[-num_sequences:]
172
- # get the tuples of start and end indexes of the sequences
173
- tuples = [(inds[i], inds[i + 1]) if i < len(inds) - 1 else (inds[i], L) for i in which_seqs]
174
- if self.troubleshoot:
175
- print(f"Sampled sequences: {tuples}")
176
- # concatenate the sequences
177
- sequences, position_ids = self.fim.apply(sequences, tuples)
178
- return sequences, position_ids
179
-
180
-
181
-
182
- def make_dataloader(dataset):
183
- """Basic function to make a dataloader.
184
- """
185
- dataloader = DataLoader(dataset)
186
- return dataloader
187
-
188
-
189
- class ProteinDataCollator(object):
190
- """
191
- Collate examples into a batch, and pad batch to a specified maximum sequence length,
192
- or to the longest sequence in the batch if max_sequence_length is None.
193
- """
194
- def __init__(self, max_sequence_length: Optional[int] = None):
195
- """
196
- Initialize the collator with an optional max_sequence_length.
197
-
198
- Args:
199
- max_sequence_length (Optional[int]): The maximum sequence length to pad/truncate to.
200
- If None, pad to the longest sequence in the batch.
201
- """
202
- self.max_sequence_length = max_sequence_length
203
-
204
- def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
205
-
206
- input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "input_ids"))
207
-
208
- longest_seq = max(len(seq) for seq in input_ids)
209
- if self.max_sequence_length is None:
210
- max_len = longest_seq
211
- else:
212
- max_len = self.max_sequence_length
213
-
214
- input_ids = self.pad_sequences(input_ids, max_len, padding_value=AA_TO_ID["<pad>"])
215
-
216
- labels = self.pad_sequences(labels, longest_seq, padding_value=AA_TO_ID["<pad>"])
217
- labels = self.pad_sequences(labels, max_len, padding_value=-100)
218
-
219
- return_dict = dict(
220
- input_ids=input_ids,
221
- labels=labels,
222
- attention_mask=input_ids.ne(AA_TO_ID["<pad>"])
223
- )
224
-
225
- if "position_ids" in instances[0]:
226
-
227
- position_ids = [instance["position_ids"] for instance in instances]
228
- position_ids = self.pad_sequences(position_ids, max_len, padding_value=0)
229
- return_dict["position_ids"] = position_ids
230
-
231
- if "seq_position_ids" in instances[0]:
232
- seq_position_ids = [instance["seq_position_ids"] for instance in instances]
233
- seq_position_ids = self.pad_sequences(seq_position_ids, max_len, padding_value=0)
234
- return_dict["seq_position_ids"] = seq_position_ids
235
-
236
- return return_dict
237
-
238
- def pad_sequences(self, seqs, max_length, padding_value):
239
- # truncate long sequences (redundant, already done in __getitem__, maybe safe to remove)
240
- seqs = [seq[:max_length] for seq in seqs]
241
-
242
- # pad to same length
243
- seqs = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=padding_value)
244
-
245
- # pad to max length
246
- padding = max_length - seqs.size(1)
247
- seqs = torch.nn.functional.pad(seqs, (0, padding), value=padding_value)
248
-
249
- return seqs
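
For reference, a minimal sketch of how the dataset and collator above are wired together for training or evaluation (paths, lengths, and batch size are illustrative):

import torch
from protxlstm.dataloaders import ProteinMemmapDataset, ProteinDataCollator

dataset = ProteinMemmapDataset(
    msa_memmap_path="data/open_protein_set_memmap.dat",
    msa_memmap_meta_path="data/open_protein_set_memmap_indices.csv",
    max_msa_len=2048,
    fim_strategy="multiple_span",
    add_position_ids="1d",
)
collator = ProteinDataCollator(max_sequence_length=2048)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=collator)

batch = next(iter(loader))
print(batch["input_ids"].shape, batch["attention_mask"].shape)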
 
 
 
 
protxlstm/fim.py DELETED
@@ -1,203 +0,0 @@
1
-
2
- # Original code from ProtMamba under Apache License 2.0.
3
-
4
- from protxlstm.utils import MASK_TO_ID, AA_TO_ID
5
- import numpy as np
6
-
7
- class AbstractFIM(object):
8
- def __init__(self,
9
- max_patches=5,
10
- mask_fraction=0.2,
11
- always_mask=False,
12
- mask_tokens=MASK_TO_ID,
13
- eos_token=AA_TO_ID["<eos>"],
14
- add_position_ids=False,
15
- troubleshoot=False):
16
- """
17
- This class is designed to concatenate sequences based on different scrambling strategies.
18
- It takes a list of sequences, tuples indicating the start and end indices of each sequence,
19
- an optional number of patches to sample, and a scrambling strategy as inputs.
20
- """
21
- self.troubleshoot = troubleshoot
22
- self.max_patches = max_patches
23
- self.mask_fraction = mask_fraction
24
- self.mask_tokens = mask_tokens
25
- assert len(
26
- self.mask_tokens) >= self.max_patches, "Number of mask tokens must be bigger than max number of patches."
27
- self.eos_token = eos_token
28
- self.add_position_ids = add_position_ids
29
- self.always_mask = always_mask
30
-
31
- def apply(self, sequences, tuples):
32
- """
33
- This function concatenates the sequences scrambling each one according to the scrambling strategy.
34
- """
35
- input_ids, position_ids = [], []
36
- for t in tuples:
37
- seq, pos = self.fim(sequences, t)
38
- input_ids.extend(seq)
39
- if self.add_position_ids:
40
- position_ids.extend(pos)
41
- if self.add_position_ids:
42
- return input_ids, position_ids
43
- return input_ids, None
44
-
45
- def fim(self, sequences, t):
46
- """
47
- This function concatenates the sequence's parts based on the scrambling strategy.
48
- """
49
- raise NotImplementedError
50
-
51
-
52
- class NoFIM(AbstractFIM):
53
- def __init__(self,
54
- max_patches=5,
55
- mask_fraction=0.2,
56
- always_mask=False,
57
- mask_tokens=MASK_TO_ID,
58
- eos_token=AA_TO_ID["<eos>"],
59
- add_position_ids=False,
60
- troubleshoot=False):
61
- super().__init__(max_patches, mask_fraction, always_mask, mask_tokens, eos_token, add_position_ids, troubleshoot)
62
-
63
- def fim(self, sequences, t):
64
- """
65
- This function keeps the sequence identical without any scrambling.
66
- """
67
- if self.add_position_ids:
68
- position_ids = np.arange(t[0], t[1]) - t[0]
69
- return sequences[t[0]:t[1]], position_ids
70
- return sequences[t[0]:t[1]], None
71
-
72
-
73
- class SingleSpanFIM(AbstractFIM):
74
-
75
- def __init__(self,
76
- max_patches=5,
77
- mask_fraction=0.2,
78
- always_mask=False,
79
- mask_tokens=MASK_TO_ID,
80
- eos_token=AA_TO_ID["<eos>"],
81
- add_position_ids=False,
82
- troubleshoot=False):
83
- super().__init__(max_patches, mask_fraction, always_mask, mask_tokens, eos_token, add_position_ids, troubleshoot)
84
-
85
- def fim(self, sequences, t):
86
- """
87
- This function creates and concatenates parts of the sequences based on the OpenAI scrambling strategy.
88
- It randomly selects two indices within the range of the given tuple,
89
- splits the sequence into three parts based on these indices, and then concatenates them with the
90
- masked patch at the end
91
- """
92
- new_tuple = tuple(np.sort(np.random.choice(np.arange(t[0] + 1, t[1]), 2, replace=False)))
93
- part1 = sequences[t[0]:new_tuple[0]]
94
- part2 = sequences[new_tuple[0]:new_tuple[1]]
95
- part3 = sequences[new_tuple[1]:t[1]]
96
- sequence = np.concatenate([part1, [self.mask_tokens["<mask-1>"]], part3, [self.mask_tokens["<mask-1>"]], part2])
97
- position_ids_sequence = None
98
- if self.add_position_ids:
99
- position_ids = np.arange(t[0], t[1]) - t[0]
100
- position_ids_part1 = position_ids[t[0]:new_tuple[0]]
101
- position_ids_part2 = position_ids[new_tuple[0]:new_tuple[1]]
102
- position_ids_part3 = position_ids[new_tuple[1]:t[1]]
103
- position_ids_sequence = np.concatenate(
104
- [position_ids_part1, [position_ids_part2[0]], position_ids_part3, [position_ids_part2[0]],
105
- position_ids_part2])
106
-
107
- return sequence, position_ids_sequence
108
-
109
-
110
- class MultipleSpanFIM(AbstractFIM):
111
- def __init__(self,
112
- max_patches=5,
113
- mask_fraction=0.2,
114
- always_mask=False,
115
- mask_tokens=MASK_TO_ID,
116
- eos_token=AA_TO_ID["<eos>"],
117
- add_position_ids=False,
118
- troubleshoot=False):
119
- super().__init__(max_patches, mask_fraction, always_mask, mask_tokens, eos_token, add_position_ids, troubleshoot)
120
-
121
- def fim(self, sequences, t):
122
- """
123
- This function creates and concatenates parts of the sequences based on the inpaint scrambling strategy.
124
- It randomly selects `2*num_patches` indices within the range of the given tuple,
125
- splits the sequence into unmasked and masked parts based on these indices, and then concatenates them.
126
- The number of patches is sampled from a poisson distribution with upper limit `self.max_patches` and average 1.
127
- The concatenation is done by joining all unmasked parts (interleaved with mask tokens) and afterwards
128
- all masked parts (interleaved with mask tokens). At the end of the unmasked parts, a special token is added
129
- to indicate the end of the unmasked parts, and at the end of the masked parts, a special token is added
130
- to indicate the end of the masked parts.
131
- """
132
- # sample num_patches from a discrete poisson distribution with upper limit L
133
- def sample_lengths(start, end):
134
- """
135
- Sample a length uniformly from 1 to max_L*self.mask_fraction (must be bigger than 1).
136
- If the length is larger than max_L, return max_L.
137
- """
138
- max_L = end - start
139
- length = np.random.randint(1, max(int(max_L * self.mask_fraction), 2))
140
- return min(length, max_L)
141
-
142
- # sample num_patches from a discrete poisson distribution with upper limit max_patches
143
- num_patches = 1000
144
- while num_patches > self.max_patches:
145
- num_patches = np.random.poisson(1)
146
- if self.always_mask:
147
- num_patches = max(num_patches, 1)
148
- # sample num_patches starting points for the masked positions (+ final position)
149
- start_patches = list(np.sort(np.random.choice(np.arange(t[0] + 1, t[1]),
150
- num_patches,
151
- replace=False))) + [t[1]]
152
- # sample num_patches lengths of the patches
153
- len_patches = [sample_lengths(start_patches[i], start_patches[i + 1])
154
- for i in range(len(start_patches) - 1)]
155
- # create masked tuples with start and end indices of the patches
156
- masked_tuples = [(start_patches[i], start_patches[i] + len_patches[i]) for i in range(len(start_patches) - 1)]
157
- # split the sequences into unmasked and masked parts
158
- unmasked_sequence, masked_sequence, unmasked_position_ids, masked_position_ids = self.split_sequences(sequences,
159
- t,
160
- masked_tuples)
161
-
162
- if self.troubleshoot:
163
- print(f"For sequence in {t}: sampled {num_patches=}, {start_patches=}, {len_patches=}, {masked_tuples=}")
164
- # concatenate the unmasked and masked parts
165
- return unmasked_sequence + masked_sequence, unmasked_position_ids + masked_position_ids if self.add_position_ids else None
166
-
167
- def split_sequences(self, sequences, t, masked_tuples):
168
- """
169
- This function splits the sequences into unmasked and masked parts based on the given tuples.
170
- Args:
171
- t (tuple): The start and end index of each sequence.
172
- masked_tuples (list): A list of tuples specifying the indices for masked regions.
173
- Returns:
174
- unmasked_parts (list): The unmasked parts of the sequences interleaved with mask_tokens.
175
- masked_parts (list): The masked parts of the sequences interleaved with mask_tokens.
176
- """
177
- unmasked_parts, masked_parts = [], []
178
- unmasked_positions, masked_positions = [], []
179
- position_ids = None
180
- start, end = t
181
- if self.add_position_ids:
182
- position_ids = np.arange(start, end) - start
183
- for i, region in enumerate(masked_tuples):
184
- mask_token = self.mask_tokens[f"<mask-{i + 1}>"]
185
- unmasked_parts.extend(sequences[start:region[0]])
186
- unmasked_parts.append(mask_token)
187
- masked_parts.append(mask_token)
188
- masked_parts.extend(sequences[region[0]:region[1]])
189
- if self.add_position_ids:
190
- unmasked_positions.extend(position_ids[start-t[0]:region[0]-t[0]])
191
- unmasked_positions.append(position_ids[region[0]-t[0]])
192
- masked_positions.append(position_ids[region[0]-t[0]])
193
- masked_positions.extend(position_ids[region[0]-t[0]:region[1]-t[0]])
194
-
195
- start = region[1]
196
- unmasked_parts.extend(sequences[start:end])
197
- if self.add_position_ids:
198
- unmasked_positions.extend(position_ids[start-t[0]:end-t[0]])
199
- if len(masked_tuples) > 0:
200
- unmasked_parts.append(self.eos_token)
201
- if self.add_position_ids:
202
- unmasked_positions.append(0)
203
- return unmasked_parts, masked_parts, unmasked_positions, masked_positions
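
To make the multiple-span layout above concrete, a toy example of the ordering it produces; the tokens are symbolic placeholders rather than real vocabulary IDs:

# One masked patch covering the two tokens D and E of a six-token sequence.
tokens = ["A", "B", "C", "D", "E", "F"]
span = (3, 5)  # half-open interval of the hidden patch

unmasked_part = tokens[: span[0]] + ["<mask-1>"] + tokens[span[1] :] + ["<eos>"]
masked_part = ["<mask-1>"] + tokens[span[0] : span[1]]

# Context with a hole, then the held-out span reintroduced by its mask token:
# ['A', 'B', 'C', '<mask-1>', 'F', '<eos>', '<mask-1>', 'D', 'E']
print(unmasked_part + masked_part)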
 
 
 
 
protxlstm/index.html DELETED
@@ -1,16 +0,0 @@
1
- <!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 3.2 Final//EN">
2
- <html>
3
- <head>
4
- <title>Index of /research/Bio-xLSTM/downloads/Prot-xLSTM/checkpoints/protxlstm_26M_30B</title>
5
- </head>
6
- <body>
7
- <h1>Index of /research/Bio-xLSTM/downloads/Prot-xLSTM/checkpoints/protxlstm_26M_30B</h1>
8
- <pre><img src="/icons/blank.gif" alt="Icon "> <a href="?C=N;O=D">Name</a> <a href="?C=M;O=A">Last modified</a> <a href="?C=S;O=A">Size</a> <a href="?C=D;O=A">Description</a><hr><img src="/icons/back.gif" alt="[PARENTDIR]"> <a href="/research/Bio-xLSTM/downloads/Prot-xLSTM/checkpoints/">Parent Directory</a> -
9
- <img src="/icons/unknown.gif" alt="[ ]"> <a href="config.json">config.json</a> 2024-11-04 14:36 1.8K
10
- <img src="/icons/unknown.gif" alt="[ ]"> <a href="optimizer.pt">optimizer.pt</a> 2024-11-04 14:36 198M
11
- <img src="/icons/binary.gif" alt="[ ]"> <a href="pytorch_model.bin">pytorch_model.bin</a> 2024-11-04 14:36 99M
12
- <img src="/icons/unknown.gif" alt="[ ]"> <a href="rng_state.pth">rng_state.pth</a> 2024-11-04 14:36 14K
13
- <img src="/icons/unknown.gif" alt="[ ]"> <a href="scheduler.pt">scheduler.pt</a> 2024-11-04 14:36 1.0K
14
- <img src="/icons/unknown.gif" alt="[ ]"> <a href="trainer_state.json">trainer_state.json</a> 2024-11-04 14:36 2.4M
15
- <hr></pre>
16
- </body></html>
 
 
 
 
protxlstm/models/llama.py DELETED
@@ -1,342 +0,0 @@
1
- import json
2
- import math
3
- import os
4
- from collections import namedtuple
5
- from typing import Optional, Tuple
6
-
7
- import torch
8
- import torch.nn as nn
9
- import torch.nn.functional as F
10
- from transformers import PretrainedConfig
11
-
12
- from protxlstm.xlstm.components.rotary_position import compute_freqs_cis
13
-
14
- # Note: generation capabilities are not implemented for the transformer
15
-
16
- class TransformerConfig(PretrainedConfig):
17
-
18
- model_type = "llama"
19
-
20
- def __init__(
21
- self,
22
- d_model,
23
- n_layer,
24
- n_heads,
25
- n_kv_heads,
26
- bidirectional,
27
- vocab_size,
28
- hidden_dim,
29
- multiple_of, # MLP hidden layer size will be multiple of
30
- norm_eps,
31
- max_length,
32
- dropout,
33
- max_position_embeddings,
34
- rope_base_frequency,
35
- **kwargs
36
- ):
37
- super().__init__(**kwargs)
38
-
39
- # default hyperparameters for the Llama 7B model
40
- self.dim = d_model
41
- self.n_layers = n_layer
42
- self.n_heads = n_heads
43
- self.n_kv_heads = n_kv_heads
44
- self.causal_attention = not bidirectional
45
- self.vocab_size = vocab_size
46
- self.hidden_dim = hidden_dim
47
- self.multiple_of = multiple_of
48
- self.norm_eps = norm_eps
49
- self.max_seq_len = max_length
50
- self.dropout = dropout
51
- self.max_position_embeddings = max_position_embeddings
52
- self.rope_base_frequency = rope_base_frequency
53
-
54
- class RMSNorm_transformer(torch.nn.Module):
55
- def __init__(self, dim: int, eps: float):
56
- super().__init__()
57
- self.eps = eps
58
- self.weight = nn.Parameter(torch.ones(dim))
59
-
60
- def _norm(self, x):
61
- return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
62
-
63
- def forward(self, x):
64
- output = self._norm(x.float()).type_as(x)
65
- return output * self.weight
66
-
67
- def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
68
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
69
- t = torch.arange(end, device=freqs.device) # type: ignore
70
- freqs = torch.outer(t, freqs).float() # type: ignore
71
- freqs_cos = torch.cos(freqs) # real part
72
- freqs_sin = torch.sin(freqs) # imaginary part
73
- return freqs_cos, freqs_sin
74
-
75
- def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
76
- ndim = x.ndim
77
- assert 0 <= 1 < ndim
78
- assert freqs_cis.shape == (x.shape[1], x.shape[-1])
79
- shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
80
- return freqs_cis.view(shape)
81
-
82
- def apply_rotary_emb(
83
- xq: torch.Tensor,
84
- xk: torch.Tensor,
85
- freqs_cos: torch.Tensor,
86
- freqs_sin: torch.Tensor
87
- ) -> Tuple[torch.Tensor, torch.Tensor]:
88
-
89
- # reshape xq and xk to match the complex representation
90
- xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
91
- xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
92
-
93
- # reshape freqs_cos and freqs_sin for broadcasting
94
- freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
95
- freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
96
-
97
- # apply rotation using real numbers
98
- xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
99
- xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
100
- xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
101
- xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
102
-
103
- # flatten last two dimensions
104
- xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
105
- xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
106
-
107
- return xq_out.type_as(xq), xk_out.type_as(xk)
108
-
109
- def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
110
- """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
111
- bs, slen, n_kv_heads, head_dim = x.shape
112
- if n_rep == 1:
113
- return x
114
- return (
115
- x[:, :, :, None, :]
116
- .expand(bs, slen, n_kv_heads, n_rep, head_dim)
117
- .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
118
- )
119
-
120
- class Attention(nn.Module):
121
- def __init__(self, args: TransformerConfig):
122
- super().__init__()
123
- self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
124
- assert args.n_heads % self.n_kv_heads == 0
125
- model_parallel_size = 1
126
- self.n_local_heads = args.n_heads // model_parallel_size
127
- self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
128
- self.n_rep = self.n_local_heads // self.n_local_kv_heads
129
- self.head_dim = args.dim // args.n_heads
130
- self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
131
- self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
132
- self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
133
- self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
134
- self.attn_dropout = nn.Dropout(args.dropout)
135
- self.resid_dropout = nn.Dropout(args.dropout)
136
- self.dropout = args.dropout
137
- self.causal_attention = args.causal_attention
138
-
139
- # use flash attention or a manual implementation?
140
- self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
141
- if not self.flash and self.causal_attention:
142
- print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
143
- mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
144
- mask = torch.triu(mask, diagonal=1)
145
- self.register_buffer("mask", mask)
146
-
147
- def forward(
148
- self,
149
- x: torch.Tensor,
150
- freqs_cos: torch.Tensor,
151
- freqs_sin: torch.Tensor,
152
- ):
153
- bsz, seqlen, _ = x.shape
154
-
155
- # QKV
156
- xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
157
- xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
158
- xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
159
- xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
160
-
161
- # RoPE relative positional embeddings
162
- xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
163
-
164
- # grouped multiquery attention: expand out keys and values
165
- xk = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
166
- xv = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
167
-
168
- # make heads into a batch dimension
169
- xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
170
- xk = xk.transpose(1, 2)
171
- xv = xv.transpose(1, 2)
172
-
173
- # flash implementation
174
- if self.flash:
175
- output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=self.causal_attention)
176
- else:
177
- # manual implementation
178
- scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
179
- if self.causal_attention:
180
- scores = scores + self.mask[:, :, :seqlen, :seqlen] # (bs, n_local_heads, seqlen, cache_len + seqlen)
181
- scores = F.softmax(scores.float(), dim=-1).type_as(xq)
182
- scores = self.attn_dropout(scores)
183
- output = torch.matmul(scores, xv) # (bs, n_local_heads, seqlen, head_dim)
184
-
185
- # restore time as batch dimension and concat heads
186
- output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
187
-
188
- # final projection into the residual stream
189
- output = self.wo(output)
190
- output = self.resid_dropout(output)
191
- return output
192
-
193
- class FeedForward(nn.Module):
194
- def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
195
- super().__init__()
196
- if hidden_dim is None:
197
- hidden_dim = 4 * dim
198
- hidden_dim = int(2 * hidden_dim / 3)
199
- hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
200
- self.w1 = nn.Linear(dim, hidden_dim, bias=False)
201
- self.w2 = nn.Linear(hidden_dim, dim, bias=False)
202
- self.w3 = nn.Linear(dim, hidden_dim, bias=False)
203
- self.dropout = nn.Dropout(dropout)
204
-
205
- def forward(self, x):
206
- return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
207
-
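# A worked example of the hidden_dim rounding in FeedForward above (values are
# illustrative, not taken from the project configs): with dim=1024, hidden_dim=None
# and multiple_of=256, the SwiGLU width becomes
#   4 * 1024 = 4096  ->  int(2 * 4096 / 3) = 2730  ->  rounded up to 11 * 256 = 2816
dim, multiple_of = 1024, 256
hidden = int(2 * (4 * dim) / 3)
hidden = multiple_of * ((hidden + multiple_of - 1) // multiple_of)
assert hidden == 2816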
208
- class TransformerBlock(nn.Module):
209
- def __init__(self, layer_id: int, args: TransformerConfig):
210
- super().__init__()
211
- self.n_heads = args.n_heads
212
- self.dim = args.dim
213
- self.head_dim = args.dim // args.n_heads
214
- self.attention = Attention(args)
215
- self.feed_forward = FeedForward(
216
- dim=args.dim,
217
- hidden_dim=args.hidden_dim,
218
- multiple_of=args.multiple_of,
219
- dropout=args.dropout,
220
- )
221
- self.layer_id = layer_id
222
- self.attention_norm = RMSNorm_transformer(args.dim, eps=args.norm_eps)
223
- self.ffn_norm = RMSNorm_transformer(args.dim, eps=args.norm_eps)
224
-
225
- def forward(self, x, freqs_cos, freqs_sin):
226
- h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
227
- out = h + self.feed_forward.forward(self.ffn_norm(h))
228
- return out
229
-
230
- class Transformer(nn.Module):
231
-
232
- last_loss: Optional[torch.Tensor]
233
-
234
- def __init__(self, params: TransformerConfig):
235
- super().__init__()
236
- self.params = params
237
- self.vocab_size = params.vocab_size
238
- self.n_layers = params.n_layers
239
-
240
- self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
241
- self.dropout = nn.Dropout(params.dropout)
242
- self.layers = torch.nn.ModuleList()
243
- for layer_id in range(params.n_layers):
244
- self.layers.append(TransformerBlock(layer_id, params))
245
- self.layer_head_dim = self.layers[0].head_dim
246
-
247
- self.norm = RMSNorm_transformer(params.dim, eps=params.norm_eps)
248
- self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
249
-
250
- # share the unembedding parameters with the embedding parameters
251
- self.tok_embeddings.weight = self.output.weight # https://paperswithcode.com/method/weight-tying
252
-
253
- # some useful precompute for the RoPE relative positional embeddings
254
- # freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
255
- # self.register_buffer("freqs_cos", freqs_cos, persistent=False)
256
- # self.register_buffer("freqs_sin", freqs_sin, persistent=False)
257
-
258
- # init all weights
259
- self.apply(self._init_weights)
260
- # apply special scaled init to the residual projections, per GPT-2 paper
261
- for pn, p in self.named_parameters():
262
- if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
263
- torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * params.n_layers))
264
-
265
- # Initialize attribute for the loss of the last forward call. This will be set if the forward is called with a targets tensor.
266
- self.last_loss = None
267
-
268
- def _init_weights(self, module):
269
- if isinstance(module, nn.Linear):
270
- torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
271
- if module.bias is not None:
272
- torch.nn.init.zeros_(module.bias)
273
- elif isinstance(module, nn.Embedding):
274
- torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
275
-
276
- def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
277
- _bsz, seqlen = tokens.shape
278
- h = self.tok_embeddings(tokens)
279
- h = self.dropout(h)
280
- # freqs_cos = self.freqs_cos[:seqlen]
281
- # freqs_sin = self.freqs_sin[:seqlen]
282
-
283
- if 'position_ids' in kwargs:
284
- freqs_cos, freqs_sin = compute_freqs_cis(kwargs.pop("position_ids"), self.layer_head_dim, theta=self.params.rope_base_frequency)
285
- else:
286
- raise ValueError('Llama model only implemented with RoPEs')
287
-
288
- freqs_cos = freqs_cos.squeeze()
289
- freqs_sin = freqs_sin.squeeze()
290
-
291
- for layer in self.layers:
292
- h = layer(h, freqs_cos, freqs_sin)
293
- h = self.norm(h)
294
-
295
- if targets is not None:
296
- logits = self.output(h)
297
- self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
298
- else:
299
- logits = self.output(h)
300
- self.last_loss = None
301
-
302
- return logits
303
-
304
- class TransformerLMHeadModel(nn.Module):
305
-
306
- def __init__(
307
- self,
308
- config: TransformerConfig,
309
- ) -> None:
310
-
311
- super().__init__()
312
-
313
- self.config = config
314
-
315
- self.backbone = Transformer(config)
316
-
317
- def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
318
- """
319
- num_last_tokens: accepted for interface compatibility; this wrapper ignores it and returns logits for all positions.
320
- """
321
-
322
- lm_logits = self.backbone(input_ids, position_ids=position_ids)
323
-
324
- CausalLMOutput = namedtuple("CausalLMOutput", ["loss", "logits"])
325
- return CausalLMOutput(loss=None, logits=lm_logits)
326
-
327
- def save_pretrained(self, save_directory):
328
- """
329
- Save the model and its configuration file to a directory.
330
- """
331
-
332
- # Ensure save_directory exists
333
- os.makedirs(save_directory, exist_ok=True)
334
-
335
- # Save the model's state_dict
336
- model_path = os.path.join(save_directory, "pytorch_model.bin")
337
- torch.save(self.state_dict(), model_path)
338
-
339
- # Save the configuration of the model
340
- config_path = os.path.join(save_directory, "config.json")
341
- with open(config_path, "w") as f:
342
- json.dump(self.config.to_dict(), f)
protxlstm/models/mamba.py DELETED
@@ -1,833 +0,0 @@
1
- # Original code from ProtMamba under Apache License 2.0.
2
-
3
- import json
4
- import os
5
- from collections import namedtuple
6
- from dataclasses import dataclass, field
7
- from functools import partial
8
-
9
- from mamba_ssm.models.config_mamba import MambaConfig
10
- from mamba_ssm.modules.mamba_simple import Block, Mamba
11
- from mamba_ssm.models.mixer_seq_simple import MixerModel, _init_weights
12
- from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
13
- from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
14
- import torch
15
- import torch.nn as nn
16
- from torch.utils.checkpoint import checkpoint
17
- from transformers import PretrainedConfig
18
-
19
- from protxlstm.generation import GenerationMixinSafe
20
-
21
- @dataclass
22
- class MambaConfig(PretrainedConfig):
23
- d_model: int = 2560
24
- n_layer: int = 64
25
- vocab_size: int = 50277
26
- ssm_cfg: dict = field(default_factory=dict)
27
- rms_norm: bool = True
28
- residual_in_fp32: bool = True
29
- fused_add_norm: bool = True
30
- pad_vocab_size_multiple: int = 8
31
- max_position_embeddings: int = 2048
32
-
33
- def create_block(
34
- d_model,
35
- ssm_cfg=None,
36
- norm_epsilon=1e-5,
37
- rms_norm=False,
38
- residual_in_fp32=False,
39
- fused_add_norm=False,
40
- layer_idx=None,
41
- device=None,
42
- dtype=None,
43
- checkpoint_mixer=False,
44
- ):
45
- if ssm_cfg is None:
46
- ssm_cfg = {}
47
- factory_kwargs = {"device": device, "dtype": dtype}
48
- mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
49
- norm_cls = partial(
50
- nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
51
- )
52
- block = Block(
53
- d_model,
54
- mixer_cls,
55
- norm_cls=norm_cls,
56
- fused_add_norm=fused_add_norm,
57
- residual_in_fp32=residual_in_fp32,
58
- )
59
- block.layer_idx = layer_idx
60
- if checkpoint_mixer:
61
- block.mixer = CheckpointedModule(block.mixer)
62
- return block
63
-
64
- class CheckpointedModule(torch.nn.Module):
65
- def __init__(self, layer):
66
- super().__init__()
67
- self.ckpt_layer = layer
68
-
69
- def forward(self, x, *args, **kwargs):
70
- return checkpoint(self.ckpt_layer, x, use_reentrant=False)
71
-
72
- # def state_dict(self, **kwargs):
73
- # # Get the state dict of the underlying layer
74
- # layer_state_dict = self.ckpt_layer.state_dict(**kwargs)
75
- # # Create a new state dict with the original keys
76
- # state_dict = {k.replace('ckpt_layer.', ''): v for k, v in layer_state_dict.items()}
77
- # return state_dict
78
-
79
- class MixerModelSafe(MixerModel):
80
- """
81
- Overwrite the forward method to allow saving intermediate layers.
82
- """
83
-
84
- def forward(self, input_ids, inference_params=None, save_layer=[]):
85
- hidden_states = self.embedding(input_ids)
86
- residual = None
87
- if len(save_layer) > 0:
88
- hidden_states_dict = {}
89
- for i, layer in enumerate(self.layers):
90
- hidden_states, residual = layer(
91
- hidden_states, residual, inference_params=inference_params
92
- )
93
- if i + 1 in save_layer:
94
- hidden_states_dict[i + 1] = (
95
- hidden_states.detach().cpu().to(torch.float).numpy()
96
- )
97
- if len(save_layer) > 0:
98
- return hidden_states_dict
99
-
100
- if not self.fused_add_norm:
101
- residual = (
102
- (hidden_states + residual) if residual is not None else hidden_states
103
- )
104
- hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
105
- else:
106
- # Set prenorm=False here since we don't need the residual
107
- fused_add_norm_fn = (
108
- rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
109
- )
110
- hidden_states = fused_add_norm_fn(
111
- hidden_states,
112
- self.norm_f.weight,
113
- self.norm_f.bias,
114
- eps=self.norm_f.eps,
115
- residual=residual,
116
- prenorm=False,
117
- residual_in_fp32=self.residual_in_fp32,
118
- )
119
- return hidden_states
120
-
121
- class MixerModelWithPosids(nn.Module):
122
- r"""Mixer model for Mamba but we add positional encodings to the input embeddings."""
123
-
124
- def __init__(
125
- self,
126
- d_model: int,
127
- n_layer: int,
128
- vocab_size: int,
129
- max_position_embeddings: int,
130
- ssm_cfg=None,
131
- norm_epsilon: float = 1e-5,
132
- rms_norm: bool = False,
133
- initializer_cfg=None,
134
- fused_add_norm=False,
135
- residual_in_fp32=False,
136
- device=None,
137
- dtype=None,
138
- checkpoint_mixer=False,
139
- ) -> None:
140
- factory_kwargs = {"device": device, "dtype": dtype}
141
- super().__init__()
142
- self.residual_in_fp32 = residual_in_fp32
143
-
144
- self.embedding = nn.Embedding(vocab_size, d_model // 2, **factory_kwargs)
145
- self.position_embedding = nn.Embedding(
146
- max_position_embeddings, d_model - d_model // 2, **factory_kwargs
147
- )
148
-
149
- # We change the order of residual and layer norm:
150
- # Instead of LN -> Attn / MLP -> Add, we do:
151
- # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
152
- # the main branch (output of MLP / Mixer). The model definition is unchanged.
153
- # This is for performance reason: we can fuse add + layer_norm.
154
- self.fused_add_norm = fused_add_norm
155
- if self.fused_add_norm:
156
- if layer_norm_fn is None or rms_norm_fn is None:
157
- raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
158
-
159
- self.layers = nn.ModuleList(
160
- [
161
- create_block(
162
- d_model,
163
- ssm_cfg=ssm_cfg,
164
- norm_epsilon=norm_epsilon,
165
- rms_norm=rms_norm,
166
- residual_in_fp32=residual_in_fp32,
167
- fused_add_norm=fused_add_norm,
168
- layer_idx=i,
169
- checkpoint_mixer=checkpoint_mixer,
170
- **factory_kwargs,
171
- )
172
- for i in range(n_layer)
173
- ]
174
- )
175
-
176
- self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
177
- d_model, eps=norm_epsilon, **factory_kwargs
178
- )
179
-
180
- self.apply(
181
- partial(
182
- _init_weights,
183
- n_layer=n_layer,
184
- **(initializer_cfg if initializer_cfg is not None else {}),
185
- )
186
- )
187
-
188
- def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
189
- return {
190
- i: layer.allocate_inference_cache(
191
- batch_size, max_seqlen, dtype=dtype, **kwargs
192
- )
193
- for i, layer in enumerate(self.layers)
194
- }
195
-
196
- def forward(self, input_ids, position_ids, inference_params=None, save_layer=[]):
197
- hidden_states = torch.cat(
198
- [
199
- self.embedding(input_ids),
200
- self.position_embedding(position_ids),
201
- ],
202
- -1,
203
- )
204
- residual = None
205
- if len(save_layer) > 0:
206
- hidden_states_dict = {}
207
- for i, layer in enumerate(self.layers):
208
- hidden_states, residual = layer(
209
- hidden_states, residual, inference_params=inference_params
210
- )
211
- if i + 1 in save_layer:
212
- hidden_states_dict[i + 1] = (
213
- hidden_states.detach().cpu().to(torch.float).numpy()
214
- )
215
- if len(save_layer) > 0:
216
- return hidden_states_dict
217
-
218
- if not self.fused_add_norm:
219
- residual = (
220
- (hidden_states + residual) if residual is not None else hidden_states
221
- )
222
- hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
223
- else:
224
- fused_add_norm_fn = (
225
- rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
226
- )
227
- hidden_states = fused_add_norm_fn(
228
- hidden_states,
229
- self.norm_f.weight,
230
- self.norm_f.bias,
231
- eps=self.norm_f.eps,
232
- residual=residual,
233
- prenorm=False,
234
- residual_in_fp32=self.residual_in_fp32,
235
- )
236
- return hidden_states
237
-
238
- class MixerModelWith2DPosids(nn.Module):
239
- r"""Mixer model for Mamba but we add positional encodings to the input embeddings."""
240
-
241
- def __init__(
242
- self,
243
- d_model: int,
244
- n_layer: int,
245
- vocab_size: int,
246
- max_position_embeddings: int,
247
- max_sequence_position_embeddings: int = 512,
248
- ssm_cfg=None,
249
- norm_epsilon: float = 1e-5,
250
- rms_norm: bool = False,
251
- initializer_cfg=None,
252
- fused_add_norm=False,
253
- residual_in_fp32=False,
254
- device=None,
255
- dtype=None,
256
- checkpoint_mixer=False,
257
- ) -> None:
258
- factory_kwargs = {"device": device, "dtype": dtype}
259
- super().__init__()
260
- self.residual_in_fp32 = residual_in_fp32
261
-
262
- self.embedding = nn.Embedding(
263
- vocab_size, d_model - 2 * d_model // 4, **factory_kwargs
264
- )
265
- self.position_embedding = nn.Embedding(
266
- max_position_embeddings, d_model // 4, **factory_kwargs
267
- )
268
- self.seq_position_embedding = nn.Embedding(
269
- max_sequence_position_embeddings, d_model // 4, **factory_kwargs
270
- )
271
- self.d_embeddings = d_model - 2 * d_model // 4
272
-
273
- # We change the order of residual and layer norm:
274
- # Instead of LN -> Attn / MLP -> Add, we do:
275
- # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
276
- # the main branch (output of MLP / Mixer). The model definition is unchanged.
277
- # This is for performance reason: we can fuse add + layer_norm.
278
- self.fused_add_norm = fused_add_norm
279
- if self.fused_add_norm:
280
- if layer_norm_fn is None or rms_norm_fn is None:
281
- raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
282
-
283
- self.layers = nn.ModuleList(
284
- [
285
- create_block(
286
- d_model,
287
- ssm_cfg=ssm_cfg,
288
- norm_epsilon=norm_epsilon,
289
- rms_norm=rms_norm,
290
- residual_in_fp32=residual_in_fp32,
291
- fused_add_norm=fused_add_norm,
292
- layer_idx=i,
293
- checkpoint_mixer=checkpoint_mixer,
294
- **factory_kwargs,
295
- )
296
- for i in range(n_layer)
297
- ]
298
- )
299
-
300
- self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
301
- d_model, eps=norm_epsilon, **factory_kwargs
302
- )
303
-
304
- self.apply(
305
- partial(
306
- _init_weights,
307
- n_layer=n_layer,
308
- **(initializer_cfg if initializer_cfg is not None else {}),
309
- )
310
- )
311
-
312
- def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
313
- return {
314
- i: layer.allocate_inference_cache(
315
- batch_size, max_seqlen, dtype=dtype, **kwargs
316
- )
317
- for i, layer in enumerate(self.layers)
318
- }
319
-
320
- def forward(
321
- self,
322
- input_ids,
323
- position_ids,
324
- seq_position_ids,
325
- inference_params=None,
326
- save_layer=[],
327
- ):
328
- hidden_states = torch.cat(
329
- [
330
- self.embedding(input_ids),
331
- self.position_embedding(position_ids),
332
- self.seq_position_embedding(seq_position_ids),
333
- ],
334
- -1,
335
- )
336
- residual = None
337
- if len(save_layer) > 0:
338
- hidden_states_dict = {}
339
- for i, layer in enumerate(self.layers):
340
- hidden_states, residual = layer(
341
- hidden_states, residual, inference_params=inference_params
342
- )
343
- if i + 1 in save_layer:
344
- hidden_states_dict[i + 1] = (
345
- hidden_states.detach().cpu().to(torch.float).numpy()
346
- )
347
- if len(save_layer) > 0:
348
- return hidden_states_dict
349
-
350
- if not self.fused_add_norm:
351
- residual = (
352
- (hidden_states + residual) if residual is not None else hidden_states
353
- )
354
- hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
355
- else:
356
- fused_add_norm_fn = (
357
- rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
358
- )
359
- hidden_states = fused_add_norm_fn(
360
- hidden_states,
361
- self.norm_f.weight,
362
- self.norm_f.bias,
363
- eps=self.norm_f.eps,
364
- residual=residual,
365
- prenorm=False,
366
- residual_in_fp32=self.residual_in_fp32,
367
- )
368
- return hidden_states
369
-
370
- class MambaLMHeadModelSafe(nn.Module, GenerationMixinSafe):
371
-
372
- def __init__(
373
- self,
374
- config: MambaConfig,
375
- initializer_cfg=None,
376
- device=None,
377
- dtype=None,
378
- checkpoint_mixer=False,
379
- ) -> None:
380
- self.config = config
381
- d_model = config.d_model
382
- n_layer = config.n_layer
383
- vocab_size = config.vocab_size
384
- ssm_cfg = config.ssm_cfg
385
- rms_norm = config.rms_norm
386
- residual_in_fp32 = config.residual_in_fp32
387
- fused_add_norm = config.fused_add_norm
388
- pad_vocab_size_multiple = config.pad_vocab_size_multiple
389
- factory_kwargs = {"device": device, "dtype": dtype}
390
- if checkpoint_mixer:
391
- raise NotImplementedError(
392
- "Checkpointing is not yet supported for MambaLMHeadModelSafe"
393
- )
394
-
395
- super().__init__()
396
- if vocab_size % pad_vocab_size_multiple != 0:
397
- vocab_size += pad_vocab_size_multiple - (
398
- vocab_size % pad_vocab_size_multiple
399
- )
400
- self.backbone = MixerModelSafe(
401
- d_model=d_model,
402
- n_layer=n_layer,
403
- vocab_size=vocab_size,
404
- ssm_cfg=ssm_cfg,
405
- rms_norm=rms_norm,
406
- initializer_cfg=initializer_cfg,
407
- fused_add_norm=fused_add_norm,
408
- residual_in_fp32=residual_in_fp32,
409
- **factory_kwargs,
410
- )
411
- self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
412
-
413
- # Initialize weights and apply final processing
414
- self.apply(
415
- partial(
416
- _init_weights,
417
- n_layer=n_layer,
418
- **(initializer_cfg if initializer_cfg is not None else {}),
419
- )
420
- )
421
- self.tie_weights()
422
-
423
- def tie_weights(self):
424
- self.lm_head.weight = self.backbone.embedding.weight
425
-
426
- def clip_grad_norm_(self, max_norm, norm_type=2.0):
427
- r"""Clip the norm of the gradients for the model.
428
- Args:
429
- max_norm (float or int): The maximum norm of the gradients.
430
- The gradients are modified in-place.
431
- norm_type (float or int): The type of the used p-norm. Can be 'inf' for infinity norm.
432
- Returns:
433
- Total norm of the parameters (viewed as a single vector).
434
- """
435
- return torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm, norm_type=norm_type)
436
-
437
- def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
438
- return self.backbone.allocate_inference_cache(
439
- batch_size, max_seqlen, dtype=dtype, **kwargs
440
- )
441
-
442
- def forward(
443
- self,
444
- input_ids,
445
- position_ids=None,
446
- inference_params=None,
447
- num_last_tokens=0,
448
- save_layer=[],
449
- *args,
450
- **kwargs,
451
- ):
452
- """
453
- "position_ids" is just to be compatible with Transformer generation. We don't use it.
454
- num_last_tokens: if > 0, only return the logits for the last n tokens
455
- """
456
- return self.protected_forward(
457
- input_ids, position_ids, inference_params, num_last_tokens, save_layer
458
- )
459
-
460
- def protected_forward(
461
- self,
462
- input_ids,
463
- position_ids=None,
464
- inference_params=None,
465
- num_last_tokens=0,
466
- save_layer=[],
467
- ):
468
- hidden_states = self.backbone(
469
- input_ids, inference_params=inference_params, save_layer=save_layer
470
- )
471
- if len(save_layer) > 0:
472
- return hidden_states
473
- if num_last_tokens > 0:
474
- hidden_states = hidden_states[:, -num_last_tokens:]
475
- lm_logits = self.lm_head(hidden_states)
476
- CausalLMOutput = namedtuple("CausalLMOutput", ["loss", "logits"])
477
- return CausalLMOutput(loss=None, logits=lm_logits)
478
-
479
- @classmethod
480
- def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
481
- config_data = load_config_hf(pretrained_model_name)
482
- config = MambaConfig(**config_data)
483
- model = cls(config, device=device, dtype=dtype, **kwargs)
484
- model.load_state_dict(
485
- load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype),
486
- strict=False,
487
- )
488
- return model
489
-
490
- def save_pretrained(self, save_directory):
491
- """
492
- Minimal implementation of save_pretrained for MambaLMHeadModel.
493
- Save the model and its configuration file to a directory.
494
- """
495
- # Ensure save_directory exists
496
- os.makedirs(save_directory, exist_ok=True)
497
-
498
- # Save the model's state_dict
499
- model_path = os.path.join(save_directory, "pytorch_model.bin")
500
- torch.save(self.state_dict(), model_path)
501
-
502
- # Save the configuration of the model
503
- config_path = os.path.join(save_directory, "config.json")
504
- with open(config_path, "w") as f:
505
- json.dump(self.config.__dict__, f)
506
-
507
- class MambaLMHeadModelwithPosids(nn.Module, GenerationMixinSafe):
508
-
509
- def __init__(
510
- self,
511
- config: MambaConfig,
512
- initializer_cfg=None,
513
- device=None,
514
- dtype=None,
515
- checkpoint_mixer=False,
516
- ) -> None:
517
- self.config = config
518
- d_model = config.d_model
519
- n_layer = config.n_layer
520
- vocab_size = config.vocab_size
521
- max_position_embeddings = config.max_position_embeddings
522
- ssm_cfg = config.ssm_cfg
523
- rms_norm = config.rms_norm
524
- residual_in_fp32 = config.residual_in_fp32
525
- fused_add_norm = config.fused_add_norm
526
- pad_vocab_size_multiple = config.pad_vocab_size_multiple
527
- factory_kwargs = {"device": device, "dtype": dtype}
528
-
529
- super().__init__()
530
- if vocab_size % pad_vocab_size_multiple != 0:
531
- vocab_size += pad_vocab_size_multiple - (
532
- vocab_size % pad_vocab_size_multiple
533
- )
534
- self.backbone = MixerModelWithPosids(
535
- d_model=d_model,
536
- n_layer=n_layer,
537
- vocab_size=vocab_size,
538
- max_position_embeddings=max_position_embeddings,
539
- ssm_cfg=ssm_cfg,
540
- rms_norm=rms_norm,
541
- initializer_cfg=initializer_cfg,
542
- fused_add_norm=fused_add_norm,
543
- residual_in_fp32=residual_in_fp32,
544
- checkpoint_mixer=checkpoint_mixer,
545
- **factory_kwargs,
546
- )
547
- self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
548
-
549
- # Initialize weights and apply final processing
550
- self.apply(
551
- partial(
552
- _init_weights,
553
- n_layer=n_layer,
554
- **(initializer_cfg if initializer_cfg is not None else {}),
555
- )
556
- )
557
- self.tie_weights()
558
-
559
- def tie_weights(self):
560
- self.lm_head.weight = self.backbone.embedding.weight
561
-
562
- def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
563
- return self.backbone.allocate_inference_cache(
564
- batch_size, max_seqlen, dtype=dtype, **kwargs
565
- )
566
-
567
- def forward(
568
- self,
569
- input_ids,
570
- position_ids=None,
571
- inference_params=None,
572
- num_last_tokens=0,
573
- save_layer=[],
574
- *args,
575
- **kwargs,
576
- ):
577
- """
578
- "position_ids" is just to be compatible with Transformer generation. We don't use it.
579
- num_last_tokens: if > 0, only return the logits for the last n tokens
580
- """
581
- return self.protected_forward(
582
- input_ids, position_ids, inference_params, num_last_tokens, save_layer
583
- )
584
-
585
- def protected_forward(
586
- self,
587
- input_ids,
588
- position_ids=None,
589
- inference_params=None,
590
- num_last_tokens=0,
591
- save_layer=[],
592
- ):
593
- hidden_states = self.backbone(
594
- input_ids,
595
- position_ids=position_ids,
596
- inference_params=inference_params,
597
- save_layer=save_layer,
598
- )
599
- if len(save_layer) > 0:
600
- return hidden_states
601
- hidden_states = hidden_states[:, :, : self.config.d_model // 2]
602
- if num_last_tokens > 0:
603
- hidden_states = hidden_states[:, -num_last_tokens:]
604
- lm_logits = self.lm_head(hidden_states)
605
- CausalLMOutput = namedtuple("CausalLMOutput", ["loss", "logits"])
606
- return CausalLMOutput(loss=None, logits=lm_logits)
607
-
608
- @classmethod
609
- def from_pretrained(
610
- cls,
611
- pretrained_model_name,
612
- device=None,
613
- dtype=None,
614
- checkpoint_mixer=False,
615
- **kwargs,
616
- ):
617
- config_data = load_config_hf(pretrained_model_name)
618
- config = MambaConfig(**config_data)
619
- model = cls(
620
- config,
621
- device=device,
622
- dtype=dtype,
623
- checkpoint_mixer=checkpoint_mixer,
624
- **kwargs,
625
- )
626
- state_dict = load_state_dict_hf(
627
- pretrained_model_name, device=device, dtype=dtype
628
- )
629
- if state_dict.keys() != model.state_dict().keys():
630
- if checkpoint_mixer:
631
- for key in model.state_dict().keys():
632
- if "ckpt_layer" in key:
633
- state_dict[key] = state_dict.pop(key.replace("ckpt_layer.", ""))
634
- print(
635
- "Using a model that was pretrained without gradient checkpointing and now want to use it. Changed the keys of the state_dict to match the model's keys."
636
- )
637
- else:
638
- for key in list(state_dict.keys()):
639
- if "ckpt_layer" in key:
640
- state_dict[key.replace("ckpt_layer.", "")] = state_dict.pop(key)
641
- print(
642
- "Using a model that was pretrained with gradient checkpointing but now do not want to use it. Changed the keys of the state_dict to match the model's keys."
643
- )
644
- assert (
645
- state_dict.keys() == model.state_dict().keys()
646
- ), "The keys of the state_dict do not match the model's keys."
647
- model.load_state_dict(state_dict)
648
- return model
649
-
650
- def save_pretrained(self, save_directory):
651
- """
652
- Minimal implementation of save_pretrained for MambaLMHeadModel.
653
- Save the model and its configuration file to a directory.
654
- """
655
- # Ensure save_directory exists
656
- os.makedirs(save_directory, exist_ok=True)
657
-
658
- # Save the model's state_dict
659
- model_path = os.path.join(save_directory, "pytorch_model.bin")
660
- torch.save(self.state_dict(), model_path)
661
-
662
- # Save the configuration of the model
663
- config_path = os.path.join(save_directory, "config.json")
664
- with open(config_path, "w") as f:
665
- json.dump(self.config.__dict__, f)
666
-
667
- class MambaLMHeadModelwith2DPosids(nn.Module, GenerationMixinSafe):
668
-
669
- def __init__(
670
- self,
671
- config: MambaConfig,
672
- initializer_cfg=None,
673
- device=None,
674
- dtype=None,
675
- checkpoint_mixer=False,
676
- ) -> None:
677
- self.config = config
678
- d_model = config.d_model
679
- n_layer = config.n_layer
680
- vocab_size = config.vocab_size
681
- max_position_embeddings = config.max_position_embeddings
682
- ssm_cfg = config.ssm_cfg
683
- rms_norm = config.rms_norm
684
- residual_in_fp32 = config.residual_in_fp32
685
- fused_add_norm = config.fused_add_norm
686
- pad_vocab_size_multiple = config.pad_vocab_size_multiple
687
- factory_kwargs = {"device": device, "dtype": dtype}
688
-
689
- super().__init__()
690
- if vocab_size % pad_vocab_size_multiple != 0:
691
- vocab_size += pad_vocab_size_multiple - (
692
- vocab_size % pad_vocab_size_multiple
693
- )
694
- self.backbone = MixerModelWith2DPosids(
695
- d_model=d_model,
696
- n_layer=n_layer,
697
- vocab_size=vocab_size,
698
- max_position_embeddings=max_position_embeddings,
699
- ssm_cfg=ssm_cfg,
700
- rms_norm=rms_norm,
701
- initializer_cfg=initializer_cfg,
702
- fused_add_norm=fused_add_norm,
703
- residual_in_fp32=residual_in_fp32,
704
- checkpoint_mixer=checkpoint_mixer,
705
- **factory_kwargs,
706
- )
707
- self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
708
-
709
- # Initialize weights and apply final processing
710
- self.apply(
711
- partial(
712
- _init_weights,
713
- n_layer=n_layer,
714
- **(initializer_cfg if initializer_cfg is not None else {}),
715
- )
716
- )
717
- self.tie_weights()
718
-
719
- def tie_weights(self):
720
- self.lm_head.weight = self.backbone.embedding.weight
721
-
722
- def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
723
- return self.backbone.allocate_inference_cache(
724
- batch_size, max_seqlen, dtype=dtype, **kwargs
725
- )
726
-
727
- def forward(
728
- self,
729
- input_ids,
730
- position_ids=None,
731
- seq_position_ids=None,
732
- inference_params=None,
733
- num_last_tokens=0,
734
- save_layer=[],
735
- *args,
736
- **kwargs,
737
- ):
738
- """
739
- "position_ids" is just to be compatible with Transformer generation. We don't use it.
740
- num_last_tokens: if > 0, only return the logits for the last n tokens
741
- """
742
- return self.protected_forward(
743
- input_ids,
744
- position_ids,
745
- seq_position_ids,
746
- inference_params,
747
- num_last_tokens,
748
- save_layer,
749
- )
750
-
751
- def protected_forward(
752
- self,
753
- input_ids,
754
- position_ids=None,
755
- seq_position_ids=None,
756
- inference_params=None,
757
- num_last_tokens=0,
758
- save_layer=[],
759
- ):
760
- hidden_states = self.backbone(
761
- input_ids,
762
- position_ids=position_ids,
763
- seq_position_ids=seq_position_ids,
764
- inference_params=inference_params,
765
- save_layer=save_layer,
766
- )
767
- if len(save_layer) > 0:
768
- return hidden_states
769
- hidden_states = hidden_states[:, :, : self.backbone.d_embeddings]
770
- if num_last_tokens > 0:
771
- hidden_states = hidden_states[:, -num_last_tokens:]
772
- lm_logits = self.lm_head(hidden_states)
773
- CausalLMOutput = namedtuple("CausalLMOutput", ["loss", "logits"])
774
- return CausalLMOutput(loss=None, logits=lm_logits)
775
-
776
- @classmethod
777
- def from_pretrained(
778
- cls,
779
- pretrained_model_name,
780
- device=None,
781
- dtype=None,
782
- checkpoint_mixer=False,
783
- **kwargs,
784
- ):
785
- config_data = load_config_hf(pretrained_model_name)
786
- config = MambaConfig(**config_data)
787
- model = cls(
788
- config,
789
- device=device,
790
- dtype=dtype,
791
- checkpoint_mixer=checkpoint_mixer,
792
- **kwargs,
793
- )
794
- state_dict = load_state_dict_hf(
795
- pretrained_model_name, device=device, dtype=dtype
796
- )
797
- if state_dict.keys() != model.state_dict().keys():
798
- if checkpoint_mixer:
799
- for key in model.state_dict().keys():
800
- if "ckpt_layer" in key:
801
- state_dict[key] = state_dict.pop(key.replace("ckpt_layer.", ""))
802
- print(
803
- "Using a model that was pretrained without gradient checkpointing and now want to use it. Changed the keys of the state_dict to match the model's keys."
804
- )
805
- else:
806
- for key in list(state_dict.keys()):
807
- if "ckpt_layer" in key:
808
- state_dict[key.replace("ckpt_layer.", "")] = state_dict.pop(key)
809
- print(
810
- "Using a model that was pretrained with gradient checkpointing but now do not want to use it. Changed the keys of the state_dict to match the model's keys."
811
- )
812
- assert (
813
- state_dict.keys() == model.state_dict().keys()
814
- ), "The keys of the state_dict do not match the model's keys."
815
- model.load_state_dict(state_dict)
816
- return model
817
-
818
- def save_pretrained(self, save_directory):
819
- """
820
- Minimal implementation of save_pretrained for MambaLMHeadModel.
821
- Save the model and its configuration file to a directory.
822
- """
823
- # Ensure save_directory exists
824
- os.makedirs(save_directory, exist_ok=True)
825
-
826
- # Save the model's state_dict
827
- model_path = os.path.join(save_directory, "pytorch_model.bin")
828
- torch.save(self.state_dict(), model_path)
829
-
830
- # Save the configuration of the model
831
- config_path = os.path.join(save_directory, "config.json")
832
- with open(config_path, "w") as f:
833
- json.dump(self.config.__dict__, f)
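The two position-aware variants above concatenate token and position embeddings along the channel dimension and later slice the hidden states back down to the token channels before the LM head. A minimal sketch of that channel budget (the d_model value is an assumption; only the arithmetic mirrors the code):

import torch

d_model = 1024
# MixerModelWithPosids: token embedding | residue-position embedding
tok_1d, pos_1d = d_model // 2, d_model - d_model // 2
# MixerModelWith2DPosids: token embedding | residue-position | sequence-position
pos_2d = d_model // 4
tok_2d = d_model - 2 * pos_2d
assert tok_1d + pos_1d == d_model and tok_2d + 2 * pos_2d == d_model

# Before the LM head, only the token channels are kept, mirroring
# hidden_states[:, :, : d_model // 2] and hidden_states[:, :, : backbone.d_embeddings].
hidden_states = torch.randn(2, 7, d_model)
token_part_1d = hidden_states[:, :, :tok_1d]
token_part_2d = hidden_states[:, :, :tok_2d]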
protxlstm/plot_utils.py DELETED
@@ -1,26 +0,0 @@
1
-
2
- cd = { # color dictionary; use the entry matching the model type
3
- "xLSTM": "#3073AD",
4
- "Transformers": "#4B9D7A",
5
- "Mamba": "#DF8953",
6
- "S4": "#D275AB",
7
- "Hyena": "#E86A61",
8
- }
9
-
10
- def setup_matplotlib():
11
- import matplotlib.pyplot as plt
12
- from tueplots import bundles, axes
13
- bundles.icml2022()
14
- plt.rcParams.update(bundles.icml2022())
15
- plt.rcParams.update(axes.lines(base_width=0.5))
16
- plt.rcParams["text.usetex"] = False
17
- plt.rcParams['font.family'] = "sans-serif"
18
- plt.rcParams['font.serif'] = 'Arial'
19
- plt.rcParams['legend.edgecolor'] = 'grey'
20
- plt.rcParams['legend.framealpha'] = 0.7
21
- plt.rcParams['lines.linewidth'] = 1.2
22
- plt.rcParams['axes.grid'] = True
23
- plt.rcParams['axes.grid.axis'] = 'both'
24
- plt.rcParams['grid.alpha'] = 0.2
25
- plt.rcParams['axes.grid'] = True
26
- plt.rcParams['axes.prop_cycle'] = plt.cycler(color=cd.values())
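A short usage sketch for the helper above (the plotted values and output path are placeholders):

import matplotlib.pyplot as plt

setup_matplotlib()  # applies the ICML bundle, fonts, grid settings and the model color cycle
fig, ax = plt.subplots()
ax.plot([1, 2, 3], [0.9, 0.7, 0.6], label="xLSTM", color=cd["xLSTM"])
ax.plot([1, 2, 3], [0.9, 0.8, 0.75], label="Mamba", color=cd["Mamba"])
ax.set_xlabel("epoch")
ax.set_ylabel("validation loss")
ax.legend()
fig.savefig("example_curve.png")  # placeholder output path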
protxlstm/train.py DELETED
@@ -1,338 +0,0 @@
1
- # Original code from ProtMamba under Apache License 2.0.
2
- #
3
- # Modifications made by Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
4
- # - Extended to training of xlstm and transformer-based models
5
- # - Predefined splits instead of on-the-fly creation
6
- # - Option to overwrite config parameters from the command line
7
- # - wandb logging
8
-
9
- import argparse
10
- import os
11
-
12
- import torch
13
- from omegaconf import OmegaConf
14
- from transformers import TrainingArguments
15
-
16
- from protxlstm.dataloaders import ProteinMemmapDataset, ProteinDataCollator
17
- from protxlstm.models.xlstm import xLSTMConfig, xLSTMLMHeadModel
18
- from protxlstm.models.llama import TransformerConfig, TransformerLMHeadModel
19
- from protxlstm.trainer import ProtTrainer, EarlyStoppingCallback, get_last_checkpoint
20
- from protxlstm.utils import (
21
- AA_TO_ID,
22
- compute_metrics,
23
- is_zero_rank,
24
- parse_override_args,
25
- print_number_of_parameters,
26
- print_zero_rank,
27
- set_optimizer_and_scheduler,
28
- setup_wandb,
29
- load_model,
30
- )
31
-
32
- def run(config):
33
- """
34
- Run training loop.
35
-
36
- Args:
37
- config (dict): dictionary with the configuration parameters.
38
- """
39
-
40
- if config.model_type == 'llama':
41
- pe_kwargs = {
42
- 'max_position_embeddings' : config["model"]["max_position_embeddings"],
43
- 'add_position_ids' : '1d',
44
- }
45
- elif config.model_type == 'mamba':
46
- from protxlstm.models.mamba import MambaConfig, MambaLMHeadModelSafe, MambaLMHeadModelwithPosids, MambaLMHeadModelwith2DPosids
47
- pe_kwargs = {
48
- 'max_position_embeddings' : config["model"]["max_position_embeddings"],
49
- 'max_seq_position_embeddings' : config["model"]["max_seq_position_embeddings"],
50
- 'add_position_ids' : config["model"]["add_position_ids"]
51
- }
52
- else:
53
- position_embeddings = config["model"]["position_embeddings"]
54
- assert position_embeddings in ["none", "abs_1d", "abs_2d", "rot_1d", "rot_2d"]
55
- if position_embeddings != "none":
56
- position_embeddings = position_embeddings.split("_")[-1]
57
- pe_kwargs = {
58
- 'max_position_embeddings' : config["model"]["max_position_embeddings"],
59
- 'max_seq_position_embeddings' : config["model"]["max_seq_position_embeddings"],
60
- 'add_position_ids' : position_embeddings
61
- }
62
-
63
- # Setup WandB
64
- wandb_run_name = setup_wandb(config)
65
-
66
- # Load datasets
67
- dataset_params = {
68
- "msa_memmap_path": config["msa_memmap_path"],
69
- "msa_memmap_meta_path": config["msa_memmap_meta_path"],
70
- "sample": config["sample_sequences"],
71
- "max_msa_len": config["max_msa_len"],
72
- "reverse": False,
73
- "seed": config["seed_sequence_sampling"],
74
- "troubleshoot": False,
75
- "fim_strategy": config["fim_strategy"],
76
- "always_mask": config["always_mask"],
77
- **pe_kwargs,
78
- }
79
- train_dataset = ProteinMemmapDataset(subset_path=config["train_set"], **dataset_params)
80
- valid_dataset = ProteinMemmapDataset(subset_path=config["valid_set"], **dataset_params)
81
- train_eval_dataset = ProteinMemmapDataset(subset_path=config["train_eval_set"], **dataset_params)
82
-
83
- print(f'Train set size: {len(train_dataset)} Train eval set size: {len(train_eval_dataset)} Valid set size: {len(valid_dataset)}')
84
-
85
- assert (
86
- len(AA_TO_ID) == config["model"]["vocab_size"]
87
- ), f"Vocab size in the config file does not match the one in the code. I should be {len(AA_TO_ID)}"
88
-
89
- # Create data collator for batched training
90
- data_collator = ProteinDataCollator(max_sequence_length=config["max_msa_len"])
91
-
92
- # Check datatypes
93
- if config["dtype"] == "float32":
94
- dtype = torch.float32
95
- elif config["dtype"] == "bfloat16":
96
- dtype = torch.bfloat16
97
- else:
98
- raise ValueError("dtype must be either float32 or bfloat16")
99
-
100
- # Initialize model
101
- if config.model_type == 'xlstm':
102
-
103
- # Load model for finetuning
104
- if config.finetune_model_path:
105
- # These fields are updated in the config loaded from the checkpoint
106
- config_update_kwargs = {
107
- "mlstm_backend": config["model"]["mlstm_block"]["mlstm"]["backend"],
108
- "mlstm_chunksize": config["model"]["mlstm_block"]["mlstm"]["chunk_size"],
109
- "checkpoint_blocks": config["model"]["checkpoint_blocks"],
110
- "rope_base_frequency": config["model"]["rope_base_frequency"]
111
- }
112
- model = load_model(
113
- config.finetune_model_path,
114
- model_class=xLSTMLMHeadModel,
115
- device="cuda",
116
- dtype=dtype,
117
- **config_update_kwargs
118
- )
119
- else:
120
- # Create a new model
121
- xlstm_config = xLSTMConfig().init_from_dict(config["model"])
122
- model = xLSTMLMHeadModel(xlstm_config)
123
-
124
- elif config.model_type == 'mamba':
125
-
126
- _mamba_model = {
127
- "none": MambaLMHeadModelSafe,
128
- "1d": MambaLMHeadModelwithPosids,
129
- "2d": MambaLMHeadModelwith2DPosids,
130
- }
131
- Mamba = _mamba_model[config['model']["add_position_ids"]]
132
-
133
- # Load model for finetuning
134
- if config.finetune_model_path:
135
- model = load_model(
136
- config.finetune_model_path,
137
- model_class=Mamba,
138
- device="cuda",
139
- dtype=dtype,
140
- checkpoint_mixer=config["checkpoint_mixer"],
141
- )
142
- else:
143
- # Create a new model
144
- mamba_config = MambaConfig(d_model=config['model']["d_model"],
145
- n_layer=config['model']["n_layer"],
146
- vocab_size=config['model']["vocab_size"],
147
- residual_in_fp32=config['model']["residual_in_fp32"])
148
- model = Mamba(mamba_config, dtype=dtype, checkpoint_mixer=config['model']["checkpoint_mixer"])
149
-
150
- elif config.model_type == 'llama':
151
-
152
- llama_config = TransformerConfig(
153
- d_model=config["model"]["d_model"],
154
- n_layer=config["model"]["n_layer"],
155
- n_heads=config["model"]["n_heads"],
156
- n_kv_heads=config["model"]["n_kv_heads"],
157
- bidirectional=config["model"]["bidirectional"],
158
- hidden_dim=config["model"]["hidden_dim"],
159
- multiple_of=config["model"]["multiple_of"],
160
- norm_eps=config["model"]["norm_eps"],
161
- max_length=config["model"]["max_length"],
162
- vocab_size=config["model"]["vocab_size"],
163
- dropout=config["model"]["dropout"],
164
- max_position_embeddings=config["model"]["max_position_embeddings"],
165
- rope_base_frequency=config["model"]["rope_base_frequency"],
166
-
167
- )
168
-
169
- model = TransformerLMHeadModel(llama_config)
170
-
171
- else:
172
- raise ValueError(f"Unsupported model_type: {config.model_type}. Expected 'xlstm', 'mamba', or 'llama'.")
173
-
174
-
175
- # TODO: Improve what we want to print
176
- if is_zero_rank():
177
- print_number_of_parameters(model)
178
- print_zero_rank(f"dtype: {config['dtype']}")
179
- print_zero_rank(f"Epochs: {config['num_epochs']}")
180
- print_zero_rank(f"Batch size per GPU: {config['batch_size']}")
181
- print_zero_rank(f"Gradient accumulation steps: {config['gradient_accumulation_steps']}")
182
- eff_batch_size = config["batch_size"] * config["gradient_accumulation_steps"]
183
- nr_gpus = torch.cuda.device_count()
184
- print_zero_rank(f"GPUS: {nr_gpus}")
185
- eff_batch_size *= nr_gpus
186
- print_zero_rank(f"Effective batch size: {eff_batch_size}")
187
- print_zero_rank(
188
- f"Steps per training epoch: {len(train_dataset) // config['batch_size']}, eff. steps: {len(train_dataset) // eff_batch_size}"
189
- )
190
- print_zero_rank(f"Steps per evaluation epoch: {len(valid_dataset) // config['batch_size']}")
191
- print_zero_rank(f"Max MSA length: {config['max_msa_len']}")
192
- ev_epochs = round(
193
- config["eval_steps"] * config["batch_size"] / len(train_dataset), 3
194
- )
195
- print_zero_rank(
196
- f"Evaluation every {config['eval_steps']} steps, i.e. {ev_epochs} epochs. Effectively every {config['eval_steps']*config['gradient_accumulation_steps']} steps, i.e. {ev_epochs*config['gradient_accumulation_steps']} epochs."
197
- )
198
- if config.model_type == 'xlstm' and config["model"]["checkpoint_blocks"]:
199
- print_zero_rank("Using gradient checkpointing")
200
- if config["compute_only_fim_loss"]:
201
- print_zero_rank("Computing only FIM loss for training")
202
-
203
- # Training callbacks
204
- es_callback = EarlyStoppingCallback(
205
- train_path=config["output_dir"] + '/' + wandb_run_name, config=config
206
- )
207
- callbacks = [es_callback]
208
-
209
- # Optimizer and Schedulers
210
- optimizer, scheduler = set_optimizer_and_scheduler(
211
- config,
212
- len(train_dataset),
213
- model.parameters()
214
- )
215
-
216
- # Find checkpoint if available
217
- last_checkpoint = None
218
- if config.finetune_model_path is None:
219
- path = os.path.join(config["output_dir"], wandb_run_name)
220
- if os.path.exists(path):
221
- last_checkpoint = get_last_checkpoint(path)
222
- if last_checkpoint is None:
223
- print_zero_rank("No checkpoint found, starting training from scratch.")
224
- else:
225
- print_zero_rank(f"Resuming training from the last checkpoint: {last_checkpoint}")
226
-
227
- # Create trainer
228
- trainer = ProtTrainer(
229
- model=model,
230
- train_dataset=train_dataset,
231
- eval_dataset={"valid": valid_dataset, "train": train_eval_dataset},
232
- optimizers=(optimizer, scheduler),
233
- args=TrainingArguments(
234
- run_name=wandb_run_name,
235
- local_rank=int(os.getenv('LOCAL_RANK', '0')),
236
- learning_rate=config["learning_rate"],
237
- num_train_epochs=config["num_epochs"],
238
- per_device_train_batch_size=config["batch_size"],
239
- per_device_eval_batch_size=config["batch_size"],
240
- gradient_accumulation_steps=config["gradient_accumulation_steps"],
241
- eval_accumulation_steps=config["eval_accumulation_steps"],
242
- eval_strategy="steps",
243
- max_grad_norm=config["max_grad_norm"],
244
- bf16=config["dtype"] == "bfloat16",
245
- dataloader_num_workers=32,
246
- logging_steps=config["logging_steps"],
247
- eval_steps=config["eval_steps"],
248
- save_steps=config["save_steps"],
249
- output_dir=config["output_dir"] + '/' + wandb_run_name,
250
- logging_dir=config["output_dir"] + '/' + wandb_run_name,
251
- report_to="wandb" if is_zero_rank() else None,
252
- log_on_each_node=False,
253
- overwrite_output_dir=False,
254
- push_to_hub=False,
255
- label_names=["labels"],
256
- ),
257
- compute_only_fim_loss=config["compute_only_fim_loss"],
258
- data_collator=data_collator,
259
- compute_metrics=compute_metrics,
260
- callbacks=callbacks,
261
- )
262
-
263
- # Train model
264
- while True:
265
- if last_checkpoint is None and trainer.state.global_step == 0:
266
- eval_results = trainer.evaluate()
267
- print_zero_rank(
268
- f">>> Initial validation perplexity: {eval_results['eval_valid_perplexity/batch']:.2f}"
269
- )
270
- else:
271
- print_zero_rank(f"Resuming training from the last checkpoint: {last_checkpoint}")
272
- # Train
273
- trainer.train(resume_from_checkpoint=last_checkpoint)
274
-
275
- # Break training when the number of epochs is reached
276
- if (
277
- not es_callback.should_restart
278
- or trainer.state.epoch >= config["num_epochs"]
279
- ):
280
- eval_results = trainer.evaluate()
281
- print_zero_rank(
282
- f">>> Final Perplexity: {eval_results['eval_valid_perplexity/batch']:.2f}"
283
- )
284
- break
285
- # If the training was interrupted because of a loss spike, restart from the last checkpoint
286
- last_checkpoint = es_callback.checkpoint_path
287
-
288
- return trainer
289
-
290
- if __name__ == "__main__":
291
-
292
- # Default configuration file paths
293
- default_model_config = "configs/xlstm_default_config.yaml"
294
- default_train_config = "configs/train_default_config.yaml"
295
-
296
- parser = argparse.ArgumentParser(
297
- description="Train or finetune a model with the provided configuration."
298
- )
299
- parser.add_argument(
300
- "--model_config_path",
301
- type=str,
302
- default=default_model_config,
303
- help=f"Path to the model configuration file (default: {default_model_config})"
304
- )
305
- parser.add_argument(
306
- "--train_config_path",
307
- type=str,
308
- default=default_train_config,
309
- help=f"Path to the training and dataset configuration file (default: {default_train_config})"
310
- )
311
- parser.add_argument(
312
- "overrides",
313
- nargs=argparse.REMAINDER,
314
- help="Override configuration values using key=value format.",
315
- )
316
-
317
- args = parser.parse_args()
318
-
319
- # Check if the default config files exist, or raise an error
320
- if not os.path.exists(args.model_config_path):
321
- raise FileNotFoundError(f"Model config file not found: {args.model_config_path}")
322
- if not os.path.exists(args.train_config_path):
323
- raise FileNotFoundError(f"Train config file not found: {args.train_config_path}")
324
-
325
- # Load the model and training configurations
326
- model_config = OmegaConf.load(args.model_config_path)
327
- train_config = OmegaConf.load(args.train_config_path)
328
-
329
- # Merge the model and training configurations
330
- config = OmegaConf.merge(model_config, train_config)
331
-
332
- # Parse overrides
333
- if args.overrides:
334
- overrides = parse_override_args(args.overrides)
335
- config.merge_with(OmegaConf.create(overrides))
336
-
337
- # Run the training/finetuning process
338
- run(config)
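A small sketch of how the merged configuration and the key=value overrides compose (the keys and values below are illustrative, not the project's actual config schema):

from omegaconf import OmegaConf

model_config = OmegaConf.create({"model": {"d_model": 512, "n_layer": 8}})
train_config = OmegaConf.create({"batch_size": 4, "learning_rate": 1e-4})
config = OmegaConf.merge(model_config, train_config)

# overrides parsed from command-line arguments such as: batch_size=8 model.n_layer=16
overrides = {"batch_size": 8, "model": {"n_layer": 16}}
config.merge_with(OmegaConf.create(overrides))

assert config.batch_size == 8 and config.model.n_layer == 16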
protxlstm/trainer.py DELETED
@@ -1,123 +0,0 @@
1
- # Original code from ProtMamba under Apache License 2.0.
2
- #
3
- # Modifications made by Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
4
- # - MambaTrainer renamed to ProtTrainer
5
-
6
- import os
7
- import re
8
-
9
- import torch
10
- from transformers import Trainer, TrainerCallback
11
-
12
- from protxlstm.utils import AA_TO_ID, find_fim_indices
13
-
14
- class ProtTrainer(Trainer):
15
- """
16
- Base HuggingFace Trainer used for training.
17
-
18
- from https://github.com/havenhq/mamba-chat/blob/main/trainer/mamba_trainer.py"""
19
- def __init__(self, compute_only_fim_loss, **kwargs,):
20
- super().__init__(**kwargs)
21
- self.compute_only_fim_loss = compute_only_fim_loss
22
-
23
-
24
- def compute_loss(self, model, inputs, return_outputs=False):
25
- input_ids = inputs.pop("input_ids")
26
- labels = inputs.pop("labels")
27
- if "seq_position_ids" in inputs and "position_ids" in inputs:
28
- position_ids = inputs.pop("position_ids")
29
- seq_position_ids = inputs.pop("seq_position_ids")
30
- output = model(input_ids, position_ids=position_ids, seq_position_ids=seq_position_ids)
31
- elif "position_ids" in inputs:
32
- position_ids = inputs.pop("position_ids")
33
- output = model(input_ids, position_ids=position_ids)
34
- else:
35
- output = model(input_ids)
36
- lm_logits = output.logits
37
-
38
- labels = labels.to(lm_logits.device)
39
- shift_logits = lm_logits[:, :-1, :].contiguous()
40
- labels = labels[:, 1:].contiguous()
41
-
42
- loss_fct = torch.nn.CrossEntropyLoss()
43
- if self.compute_only_fim_loss:
44
- # start and end tokens
45
- is_cls_tokens = (labels == AA_TO_ID["<cls>"])
46
- is_eos_tokens = (labels == AA_TO_ID["<eos>"])
47
- bool_fim = find_fim_indices(is_cls_tokens, is_eos_tokens)
48
- # include also the cls token
49
- bool_fim = bool_fim | is_cls_tokens
50
- inds = torch.where(bool_fim)
51
- lm_loss = loss_fct(shift_logits[inds[0], inds[1], :], labels[bool_fim])
52
- else:
53
- lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
54
-
55
- return (lm_loss, output) if return_outputs else lm_loss
56
-
57
- def save_model(self, output_dir, _internal_call):
58
- if int(os.getenv('LOCAL_RANK', '0')) == 0:
59
- self.model.save_pretrained(output_dir)
60
-
61
- PREFIX_CHECKPOINT_DIR = "checkpoint"
62
- _re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$")
63
-
64
- def get_last_checkpoint(folder, max_steps=None):
65
- content = os.listdir(folder)
66
- checkpoints = [
67
- path
68
- for path in content
69
- if _re_checkpoint.search(path) is not None and os.path.isdir(os.path.join(folder, path))
70
- ]
71
- if len(checkpoints) == 0:
72
- return
73
-
74
- max_steps = max_steps if max_steps is not None else float("inf")
75
- # func = lambda x: int(_re_checkpoint.search(x).groups()[0])
76
- def func(x):
77
- num = int(_re_checkpoint.search(x).groups()[0])
78
- return num if num < max_steps else -1
79
- return os.path.join(folder, max(checkpoints, key=func))
80
-
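# A small sketch of the checkpoint selection above (directory names are hypothetical;
# assumes get_last_checkpoint and its regex are in scope).
import os
import tempfile

with tempfile.TemporaryDirectory() as folder:
    for step in (500, 1000, 1500):
        os.makedirs(os.path.join(folder, f"checkpoint-{step}"))
    assert get_last_checkpoint(folder).endswith("checkpoint-1500")
    # checkpoints at or beyond max_steps are mapped to -1 and therefore skipped
    assert get_last_checkpoint(folder, max_steps=1200).endswith("checkpoint-1000")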
81
- class EarlyStoppingCallback(TrainerCallback):
82
- def __init__(self, train_path, config=None):
83
- self.step_counter_reset = 0
84
- self.step_counter_stop = 0
85
- self.best_loss = None
86
- self.train_path = train_path
87
- self.patience = config["patience"]
88
- self.metric_name = config["early_stopping_metric"]
89
- self.checkpoint_path = None
90
- self.should_restart = False
91
- self.eval_steps = config["eval_steps"]
92
- self.loss_increase_factor = config["loss_increase_factor"]
93
-
94
- def get_checkpoint_path(self, max_steps):
95
- last_checkpoint = None
96
- if os.path.exists(self.train_path):
97
- last_checkpoint = get_last_checkpoint(self.train_path, max_steps)
98
- if last_checkpoint is None:
99
- print("No checkpoint found, starting training from scratch.")
100
- else:
101
- print(f"Max checkpoint allowed: {max_steps}, restarting from {last_checkpoint}.")
102
- return last_checkpoint
103
-
104
- def on_evaluate(self, args, state, control, model, metrics, **kwargs):
105
- if self.metric_name in metrics:
106
- if self.best_loss is None:
107
- self.best_loss = metrics[self.metric_name]
108
- elif self.best_loss*self.loss_increase_factor < metrics[self.metric_name]:
109
- self.step_counter += 1
110
- if self.step_counter >= self.patience:
111
- checkpoint_path = self.get_checkpoint_path(max_steps=(state.global_step-self.patience*self.eval_steps))
112
- control.should_training_stop = True
113
- self.checkpoint_path = checkpoint_path
114
- self.should_restart = True
115
- else:
116
- self.step_counter = 0
117
- self.best_loss = min(self.best_loss, metrics[self.metric_name])
118
- self.should_restart = False
119
-
120
- def on_train_begin(self, args, state, control, **kwargs):
121
- self.step_counter = 0
122
- self.best_loss = None
123
- self.should_restart = False
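A minimal sketch of the shifted, optionally masked cross-entropy in ProtTrainer.compute_loss above (the boolean mask here is a stand-in; in the trainer it comes from find_fim_indices and the <cls>/<eos> token positions):

import torch

batch, seqlen, vocab = 2, 6, 40
logits = torch.randn(batch, seqlen, vocab)
labels = torch.randint(0, vocab, (batch, seqlen))

# next-token shift: position t predicts token t+1
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = labels[:, 1:].contiguous()

loss_fct = torch.nn.CrossEntropyLoss()
full_loss = loss_fct(shift_logits.view(-1, vocab), shift_labels.view(-1))

# restrict the loss to selected positions (placeholder mask standing in for find_fim_indices)
bool_fim = torch.zeros_like(shift_labels, dtype=torch.bool)
bool_fim[:, 2:] = True
inds = torch.where(bool_fim)
fim_loss = loss_fct(shift_logits[inds[0], inds[1], :], shift_labels[bool_fim])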
run.sh DELETED
@@ -1,6 +0,0 @@
1
- #!/bin/bash
2
- CONDA_ENV=$(head -1 /code/environment.yml | cut -d" " -f2)
3
- eval "$(conda shell.bash hook)"
4
- conda activate $CONDA_ENV
5
-
6
- streamlit run app.py --server.port 7860 --server.address 0.0.0.0