Elias Buerger committed
Commit 28f312f
Parent(s): e4995f0

slim down
Browse files
- Dockerfile +0 -34
- app.py +2 -0
- environment.yml +0 -39
- protxlstm/applications/generation_utils/create_sequence_df.py +0 -85
- protxlstm/applications/generation_utils/score_hamming.py +0 -80
- protxlstm/applications/generation_utils/score_hmmer.py +0 -102
- protxlstm/applications/generation_utils/score_structure.py +0 -55
- protxlstm/applications/sample_sequences.py +0 -200
- protxlstm/applications/score_sequences.py +0 -58
- protxlstm/data.py +0 -60
- protxlstm/dataloaders.py +0 -249
- protxlstm/fim.py +0 -203
- protxlstm/index.html +0 -16
- protxlstm/models/llama.py +0 -342
- protxlstm/models/mamba.py +0 -833
- protxlstm/plot_utils.py +0 -26
- protxlstm/train.py +0 -338
- protxlstm/trainer.py +0 -123
- run.sh +0 -6
Dockerfile
DELETED
@@ -1,34 +0,0 @@
-FROM continuumio/anaconda3:main
-
-WORKDIR /code
-COPY ./environment.yml /code/environment.yml
-
-# Create the environment using the environment.yml file
-RUN conda env create -f /code/environment.yml
-
-# Set up a new user named "user" with user ID 1000
-RUN useradd -m -u 1000 user
-# Switch to the "user" user
-USER user
-# Set home to the user's home directory
-ENV HOME=/home/user \
-    PYTHONPATH=$HOME/app \
-    PYTHONUNBUFFERED=1 \
-    GRADIO_ALLOW_FLAGGING=never \
-    GRADIO_NUM_PORTS=1 \
-    GRADIO_SERVER_NAME=0.0.0.0 \
-    GRADIO_THEME=huggingface \
-    SYSTEM=spaces
-
-# Set the working directory to the user's home directory
-WORKDIR $HOME/app
-
-# Copy the current directory contents into the container at $HOME/app setting the owner to the user
-COPY --chown=user . $HOME/app
-
-# cgjs, u+x
-RUN chmod u+x $HOME/app/run.sh
-RUN chmod -R 777 $HOME/
-
-
-CMD ["./run.sh"]
app.py
CHANGED
@@ -111,11 +111,13 @@ if __name__ == "__main__":
             placeholder=DEFAULT_SEQUENCE,
         )
         st.session_state.context_sequences = context_sequence_str.split(",") + [st.session_state.target_sequence]
+        msa_file = None
     elif context_type == 'Use MSA file':
         msa_file = st.file_uploader("Choose MSA file")
         st.session_state.num_context_sequences = st.number_input("How many of these sequences should be used?", min_value=0, step=1, value=25)
     else:
         st.session_state.context_sequences = [st.session_state.target_sequence]
+        msa_file = None
 
     if st.session_state.target_sequence != "":
         with st.container():
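(The two added `msa_file = None` lines initialize `msa_file` in the branches that never create it; presumably this prevents a NameError when code further down references `msa_file` after the user picks a context type other than 'Use MSA file'.)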
environment.yml
DELETED
@@ -1,39 +0,0 @@
-name: prot_xlstm_app
-channels:
-  - pytorch
-  - nvidia
-  - conda-forge
-  - defaults
-dependencies:
-  - cuda=12.1
-  - cuda-nvcc=12.1
-  - gxx_linux-64=11.2.0
-  - python=3.11
-  - pip
-  - pytorch=2.2.0
-  - pytorch-cuda=12.1
-  - cmake
-  - ninja
-  - pip:
-    - accelerate>=0.26.0
-    - biopython #==1.83
-    - bottleneck #==1.4.2
-    - dacite #==1.8.1
-    - ipykernel #==6.29.3
-    - mamba_ssm==1.2.0
-    - matplotlib #==3.8.4
-    - numpy<2.0 #==1.26.4
-    - omegaconf #==2.3.0
-    - pandas #==2.2.2
-    - pyhmmer #==0.10.15
-    - rich #==13.7.1
-    - scipy #==1.13.0
-    - seaborn #==0.13.2
-    - torchmetrics #==1.2.1
-    - tqdm #==4.66.4
-    - transformers==4.44.2
-    - tueplots #==0.0.17
-    - wandb #==0.17.0
-    - streamlit #==1.43.2
-
-
protxlstm/applications/generation_utils/create_sequence_df.py
DELETED
@@ -1,85 +0,0 @@
-import numpy as np
-import pickle
-import pandas as pd
-
-from protxlstm.dataloaders import ProteinMemmapDataset
-from protxlstm.utils import decode_sequence, reorder_masked_sequence
-
-
-def create_sequence_df(model_name, family_idx, parameters_list=None, num_sequences=100, data_dir="./data/"):
-
-    # load dataset
-    dataset = ProteinMemmapDataset(
-        msa_memmap_path=f"{data_dir}open_protein_set_memmap.dat",
-        msa_memmap_meta_path=f"{data_dir}open_protein_set_memmap_indices.csv",
-        subset_path=f"{data_dir}/cluster_testing_set.txt",
-        sample=False,
-        max_msa_len=-1,
-        reverse=False,
-        seed=0,
-        troubleshoot=False,
-        fim_strategy="multiple_span",
-        always_mask=False,
-        max_position_embeddings=2048,
-        max_seq_position_embeddings=512,
-        add_position_ids="1d",
-        mask_fraction=0.2,
-        max_patches=5,
-    )
-
-    family_id = list(dataset.dataset_meta["msa_id"])[family_idx]
-
-    if model_name == "natural":
-
-        data = dataset[family_idx]
-        sequence_df = pd.DataFrame(columns=["family", "family_id", "sequence", "sequence_length"])
-        tokens = data["input_ids"][None,:]
-        all_context = decode_sequence(tokens[0].cpu().numpy())
-        list_sequences_msa = [reorder_masked_sequence(elem+"<cls>") for elem in all_context.split("<cls>")[1:-1]]
-
-        rd_idxs = np.random.choice(len(list_sequences_msa), num_sequences, replace=False)
-        natural_sequences = [seq for i, seq in enumerate(list_sequences_msa) if i in rd_idxs]
-
-        df_dict = {"family": [family_idx]*len(natural_sequences),
-                   "family_id": [family_id]*len(natural_sequences),
-                   "sequence": natural_sequences,
-                   "sequence_length": [len(seq) for seq in natural_sequences]}
-
-        sequence_df = pd.concat([sequence_df, pd.DataFrame(df_dict)], ignore_index=True)
-
-    else:
-
-        sequence_df = pd.DataFrame(columns=["family", "family_id", "n_seqs_ctx", "temperature", "top_k", "top_p", "original_sequence", "sequence", "sequence_length", "perplexity"])
-
-        if parameters_list is None:
-            parameters_list = [(10,1.,10,1.), (10,1.,15,1.), (10,1.,10,0.95), (10,0.9,10,0.95), (10,0.8,10,0.9),
-                               (100,1.,10,1.), (100,1.,15,1.), (100,1.,10,0.95), (100,0.9,10,0.95), (100,0.8,10,0.9),
-                               (500,1.,10,1.), (500,1.,15,1.), (500,1.,10,0.95), (500,0.9,10,0.95), (500,0.8,10,0.9),
-                               (1000,1.,10,1.), (1000,1.,15,1.), (1000,1.,10,0.95), (1000,0.9,10,0.95), (1000,0.8,10,0.9),
-                               (-1,1.,10,1.), (-1,1.,15,1.), (-1,1.,10,0.95), (-1,0.9,10,0.95), (-1,0.8,10,0.9)]
-
-        for param in parameters_list:
-            n_seqs_ctx, temperature, top_k, top_p = param
-
-            with open(f"evaluation/generation/generated_sequences/{model_name}/{family_idx}_{param}_{num_sequences}", "rb") as f:
-                gen_seqs = pickle.load(f)
-
-            original_sequences = list(gen_seqs[family_idx][param].keys())
-            reordered_sequences = [reorder_masked_sequence(seq) for seq in original_sequences]
-            perplexities = [gen_seqs[family_idx][param][seq]["perplexity"] for seq in original_sequences]
-            df_dict = {"family": [family_idx]*len(original_sequences),
-                       "family_id": [family_id]*len(original_sequences),
-                       "n_seqs_ctx": [n_seqs_ctx]*len(original_sequences),
-                       "temperature": [temperature]*len(original_sequences),
-                       "top_k": [top_k]*len(original_sequences),
-                       "top_p": [top_p]*len(original_sequences),
-                       "original_sequence": original_sequences,
-                       "sequence": reordered_sequences,
-                       "sequence_length": [len(seq) for seq in reordered_sequences],
-                       "perplexity": perplexities
-                       }
-
-            sequence_df = pd.concat([sequence_df, pd.DataFrame(df_dict)], ignore_index=True)
-
-    return sequence_df
-
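For reference, a minimal usage sketch of the deleted helper, assuming the memmap dataset files exist under ./data/ (and, for the non-"natural" branch, that sequences were pickled beforehand by sample_sequences.py); num_sequences must not exceed the family's MSA size:

    from protxlstm.applications.generation_utils.create_sequence_df import create_sequence_df

    # Collect 100 natural sequences of family 0 into a scoring DataFrame
    df = create_sequence_df("natural", family_idx=0, num_sequences=100, data_dir="./data/")
    print(df[["family_id", "sequence_length"]].describe())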
protxlstm/applications/generation_utils/score_hamming.py
DELETED
@@ -1,80 +0,0 @@
-import numpy as np
-from tqdm import tqdm
-import pandas as pd
-from Bio import Align
-
-from protxlstm.dataloaders import ProteinMemmapDataset
-from protxlstm.utils import decode_sequence, reorder_masked_sequence
-
-
-aligner = Align.PairwiseAligner()
-aligner.mode = 'global'
-aligner.match_score = 1
-aligner.mismatch_score = -1
-aligner.open_gap_score = -1
-aligner.extend_gap_score = -1
-
-def align_sequences(ref_seq, query_seq, print_alignments=False):
-    def hamming_str(s1, s2):
-        assert len(s1) == len(s2)
-        return sum(np.array(list(s1)) != np.array(list(s2)))/len(s1)
-    alignments = aligner.align(ref_seq, query_seq)
-    if print_alignments:
-        print("Score = %.1f:" % alignments[0].score)
-        print(alignments[0])
-    return hamming_str(alignments[0][0], alignments[0][1]), alignments[0][0], alignments[0][1]
-
-
-def score_hamming(sequence_df, family_idx, data_dir=f"./data/"):
-
-    assert len(set(list(sequence_df["family"]))) == 1 and sequence_df["family"].iloc[0] == family_idx
-
-    # load dataset
-    dataset = ProteinMemmapDataset(
-        msa_memmap_path=f"{data_dir}open_protein_set_memmap.dat",
-        msa_memmap_meta_path=f"{data_dir}open_protein_set_memmap_indices.csv",
-        subset_path=f"{data_dir}/cluster_testing_set.txt",
-        sample=False,
-        max_msa_len=-1,
-        reverse=False,
-        seed=0,
-        troubleshoot=False,
-        fim_strategy="multiple_span",
-        always_mask=False,
-        max_position_embeddings=2048,
-        max_seq_position_embeddings=512,
-        add_position_ids="1d",
-        mask_fraction=0.2,
-        max_patches=5,
-    )
-
-    # Select a sample of the dataset to be the input
-    data = dataset[family_idx]
-    tokens = data["input_ids"][None,:]
-    all_context = decode_sequence(tokens[0].cpu().numpy())
-    list_sequences_msa = [reorder_masked_sequence(elem+"<cls>") for elem in all_context.split("<cls>")[1:-1]]
-
-    # sequence_df["hamming"] = pd.Series(dtype=object)
-    sequence_df["min_hamming"] = pd.Series()
-    sequence_df["median_hamming"] = pd.Series()
-    sequence_df["mean_hamming"] = pd.Series()
-    sequence_df["std_hamming"] = pd.Series()
-
-    for seq in tqdm(list(sequence_df["sequence"])):
-
-        all_hamming = []
-        for ctx_seq in list_sequences_msa:
-            if ctx_seq == seq:
-                continue
-            else:
-                hamming, _, _ = align_sequences(ctx_seq, seq, print_alignments=False)
-                all_hamming.append(hamming)
-
-        # sequence_df.loc[sequence_df["sequence"] == seq, "hamming"] = [all_hamming]
-        sequence_df.loc[sequence_df["sequence"] == seq, "min_hamming"] = np.min(all_hamming)
-        sequence_df.loc[sequence_df["sequence"] == seq, "median_hamming"] = np.median(all_hamming)
-        sequence_df.loc[sequence_df["sequence"] == seq, "mean_hamming"] = np.mean(all_hamming)
-        sequence_df.loc[sequence_df["sequence"] == seq, "std_hamming"] = np.std(all_hamming)
-
-    return sequence_df
-
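A small sketch of the pairwise-alignment Hamming primitive this file built on, mirroring the deleted align_sequences (the two sequences below are made-up toy strings):

    import numpy as np
    from Bio import Align

    aligner = Align.PairwiseAligner()
    aligner.mode = "global"
    aligner.match_score = 1
    aligner.mismatch_score = -1
    aligner.open_gap_score = -1
    aligner.extend_gap_score = -1

    a, b = "MKTAYIAKQR", "MKSAYIAKQR"   # toy sequences
    aln = aligner.align(a, b)[0]
    s1, s2 = aln[0], aln[1]             # aligned strings of equal length
    hamming = sum(np.array(list(s1)) != np.array(list(s2))) / len(s1)
    print(f"normalized Hamming distance: {hamming:.2f}")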
protxlstm/applications/generation_utils/score_hmmer.py
DELETED
@@ -1,102 +0,0 @@
-import string
-from Bio import SeqIO
-import pyhmmer
-from tqdm import tqdm
-
-alphabet = pyhmmer.easel.Alphabet.amino()
-
-# This is an efficient way to delete lowercase characters and insertion characters from a string
-deletekeys = dict.fromkeys(string.ascii_lowercase)
-deletekeys["."] = None
-deletekeys["*"] = None
-translation = str.maketrans(deletekeys)
-
-def remove_insertions(sequence: str) -> str:
-    """ Removes any insertions into the sequence. Needed to load aligned sequences in an MSA. """
-    return sequence.translate(translation)
-
-def read_msa(filename: str):
-    """ Reads the sequences from an MSA file, automatically removes insertions."""
-    return [(record.description, remove_insertions(str(record.seq))) for record in SeqIO.parse(filename, "fasta")]
-
-def read_msa_unaligned(filename: str):
-    """ Reads the sequences from an MSA file, removes only . - and * characters."""
-    return [(record.description, str(record.seq).replace(".","").replace("-","").replace("*","").upper()) for record in SeqIO.parse(filename, "fasta")]
-
-def check_msa(msa):
-    """ Checks if there are any repeated sequences in the MSA"""
-    seqs = set()
-    for el in msa:
-        seqs.add(el[1])
-    assert len(seqs) == len(msa), "There are repeated sequences in the MSA"
-
-def make_hmm_from_a3m_msa(msa_filepath, hmm_filename=None):
-    # Load MSA from a3m
-    msa_tup = read_msa(msa_filepath)
-    # check_msa(msa_tup)
-    # Create digitized MSA block
-    all_seqs = [pyhmmer.easel.TextSequence(name=str(i).encode("utf-8"), sequence=seq) for i, (idz, seq) in enumerate(msa_tup)]
-    msa = pyhmmer.easel.TextMSA(name=b"msa", sequences=all_seqs)
-    msa = msa.digitize(alphabet)
-    # Fit HMM
-    builder = pyhmmer.plan7.Builder(alphabet)
-    background = pyhmmer.plan7.Background(alphabet)
-    hmm, _, _ = builder.build_msa(msa, background)
-    if hmm_filename is not None:
-        with open(f"{hmm_filename}.hmm", "wb") as output_file:
-            hmm.write(output_file)
-    return hmm
-
-def align_and_score_sequences_in_a3m_with_hmm(hmm, sequences_path=None, sequences_list=None):
-    if sequences_list is not None:
-        msa = sequences_list
-        all_seqs = [pyhmmer.easel.TextSequence(name=str(i).encode("utf-8"), sequence=seq) for i, seq in enumerate(sequences_list)]
-    elif sequences_path is not None:
-        # Load sequences from a3m
-        msa = read_msa_unaligned(sequences_path)
-        all_seqs = [pyhmmer.easel.TextSequence(name=str(i).encode("utf-8"), sequence=seq) for i, (idz, seq) in enumerate(msa)]
-    else:
-        raise NotImplementedError("Missing sequences to align/score")
-    # Create digitized Sequence block
-    seq_block = pyhmmer.easel.TextSequenceBlock(all_seqs)
-    seq_block = seq_block.digitize(alphabet)
-    # Get all hits from the hmm
-    background = pyhmmer.plan7.Background(alphabet)
-    pipeline = pyhmmer.plan7.Pipeline(alphabet, background=background, bias_filter=False, F1=1.0, F2=1.0, F3=1.0)
-    hits = pipeline.search_hmm(hmm, seq_block)
-    if len(hits) != len(msa):
-        print(f"Number of hits: {len(hits)} is different from the number of sequences in the MSA: {len(msa)}")
-    # Extract hits
-    all_hits = {}
-    for hit in hits:
-        idz, score, evalue = hit.name, hit.score, hit.evalue
-        i = int(idz.decode("utf-8"))
-        seq = msa[i][1] if sequences_path is not None else sequences_list[i]
-        all_hits[seq] = {"score": score, "evalue": evalue}
-    return all_hits
-
-
-def score_hmmer(sequence_df, family_idx, data_dir=f"./data/"):
-
-    assert len(set(list(sequence_df["family"]))) == 1 and sequence_df["family"].iloc[0] == family_idx
-
-    family_id = sequence_df["family_id"].iloc[0]
-    msa_filepath = f"{data_dir}/a3m_files/{family_id}/a3m/uniclust30.a3m"
-    try:
-        hmm = make_hmm_from_a3m_msa(msa_filepath)
-    except:
-        raise Exception(f"Missing MSA of family {family_id}")
-
-    # align sequences
-    sequences = list(sequence_df["sequence"])
-    scores = align_and_score_sequences_in_a3m_with_hmm(hmm, sequences_list=sequences)
-
-    # save the scores associated to each sequence in the main df in the columns "score" and "evalue"
-    for seq in tqdm(sequences):
-        sequence_df.loc[sequence_df["sequence"] == seq, "score_gen"] = scores[seq]["score"] if seq in scores.keys() else 0
-        sequence_df.loc[sequence_df["sequence"] == seq, "evalue_gen"] = scores[seq]["evalue"] if seq in scores.keys() else 1
-
-    return sequence_df
-
-
-
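A minimal sketch of the pyhmmer workflow in this file, using the deleted helpers directly: fit a profile HMM on an a3m MSA, then score unaligned sequences against it (the MSA path and query sequence are placeholders; sequences that do not hit the profile simply never appear in the returned dict, which is why score_hmmer falls back to score 0 / e-value 1):

    from protxlstm.applications.generation_utils.score_hmmer import (
        make_hmm_from_a3m_msa,
        align_and_score_sequences_in_a3m_with_hmm,
    )

    hmm = make_hmm_from_a3m_msa("data/a3m_files/<family_id>/a3m/uniclust30.a3m")  # placeholder path
    hits = align_and_score_sequences_in_a3m_with_hmm(hmm, sequences_list=["MKTAYIAKQR"])  # toy query
    for seq, h in hits.items():
        print(seq, h["score"], h["evalue"])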
protxlstm/applications/generation_utils/score_structure.py
DELETED
@@ -1,55 +0,0 @@
-from Bio.PDB import PDBParser
-import torch
-from tqdm import tqdm
-from transformers import EsmForProteinFolding
-
-from protxlstm.utils import MASK_TO_ID
-
-
-pdb_parser = PDBParser()
-
-
-def compute_structure(seq, model):
-    def keep_sequence(seq, l):
-        if len(seq) > l:
-            return False
-        for mm in list(MASK_TO_ID.keys())+["<eos>", "<pad>", "<unk>", "<mask>", "<cls>", "<null_1>", ".", "-"]:
-            if mm in seq:
-                return False
-        return True
-    keep = keep_sequence(seq, l=750)
-    if keep:
-        with torch.no_grad():
-            output = model.infer([seq])
-        # pdb = model.output_to_pdb(output)
-        ptm = output["ptm"].item()
-        pae = output["predicted_aligned_error"].cpu().numpy()
-        mean_plddt = ((output["plddt"] * output["atom37_atom_exists"]).sum(dim=(1, 2)) / output["atom37_atom_exists"].sum(dim=(1, 2))).item()
-        pos_plddt = ((output["plddt"] * output["atom37_atom_exists"]).sum(dim=(2,)) / output["atom37_atom_exists"].sum(dim=(2,))).cpu().numpy()
-    else:
-        print(f"Sequence is invalid.")
-        ptm, pae, mean_plddt, pos_plddt = 0, 0, 0, 0
-    return ptm, pae, mean_plddt, pos_plddt
-
-
-def score_structure(sequence_df, family_idx):
-
-    assert len(set(list(sequence_df["family"]))) == 1 and sequence_df["family"].iloc[0] == family_idx
-
-    device = "cuda:0"
-
-    # Import the folding model
-    model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True)
-
-    model = model.cuda(device)
-    model.esm = model.esm.half()
-    torch.backends.cuda.matmul.allow_tf32 = True
-
-    sequences = list(sequence_df["sequence"])
-    for seq in tqdm(sequences):
-
-        ptm, pae, mean_plddt, pos_plddt = compute_structure(seq, model)
-        sequence_df.loc[sequence_df["sequence"] == seq, "ptm"] = ptm
-        sequence_df.loc[sequence_df["sequence"] == seq, "mean_plddt"] = mean_plddt
-
-    return sequence_df
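A sketch of the ESMFold scoring step in isolation, following the same calls as the deleted compute_structure (requires a CUDA GPU and downloads facebook/esmfold_v1; the sequence below is a toy placeholder):

    import torch
    from transformers import EsmForProteinFolding

    model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True).cuda()
    model.esm = model.esm.half()

    with torch.no_grad():
        out = model.infer(["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"])  # toy sequence
    print("pTM:", out["ptm"].item())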
protxlstm/applications/sample_sequences.py
DELETED
@@ -1,200 +0,0 @@
-import torch
-from tqdm import tqdm
-import pickle
-import os
-import argparse
-import json
-
-from protxlstm.dataloaders import ProteinMemmapDataset
-from protxlstm.generation import generate_sequence
-from protxlstm.utils import (
-    AA_TO_ID,
-    load_model,
-)
-from protxlstm.models.xlstm import xLSTMLMHeadModel
-from protxlstm.models.mamba import MambaLMHeadModelwithPosids
-
-
-def sample_sequences(dataset,
-                     model,
-                     family_idx,
-                     params,
-                     n_samples_per_family,
-                     max_length=1000,
-                     chunk_chunk_size=2**15,
-                     save_path=None,
-                     device="cuda:0"):
-    """
-    Function to sample sequences from the model. Given a dataset, a list of families (their indexes in the dataset)
-    and a set of generating parameters, it generates `n_samples_per_family` sequences for each family and each parameter set.
-    The function returns a dictionary with the following structure:
-        gen_seqs = {family_idx: {parameters: {sequence: perplexity}}}
-    The parameters are in a list of tuples with the following structure:
-        parameters_list = [(nr_seqs_ctx, temperature, top_k, top_p)]
-    """
-    gen_seqs = {}
-    gen_seqs[family_idx] = {}
-    gen_seqs[family_idx][params] = {}
-    print(f"Sampling sequences for family {family_idx} and parameters {params}.")
-
-    n_seqs_ctx, temperature, top_k, top_p = params
-    for _ in tqdm(range(n_samples_per_family)):
-        # Sample the dataset to get the input
-        data = dataset[family_idx]
-        tokens = data["input_ids"][None,:].to(device)
-        pos_ids = data["position_ids"][None,:].to(device)
-
-        start_seqs = torch.argwhere(tokens[0]==0)[:,0].cpu().numpy()
-
-        n_seqs_ctx = len(start_seqs) if len(start_seqs) < n_seqs_ctx else n_seqs_ctx
-        L = start_seqs[n_seqs_ctx]+1
-        context_tokens = tokens[:,:L]
-        context_pos_ids = pos_ids[:,:L]
-        is_fim = {}
-
-        # Generate the new sequence
-        output = generate_sequence(model,
-                                   context_tokens,
-                                   position_ids=context_pos_ids,
-                                   is_fim=is_fim,
-                                   max_length=(L+max_length),
-                                   temperature=temperature,
-                                   top_k=top_k,
-                                   top_p=top_p,
-                                   return_dict_in_generate=True,
-                                   output_scores=True,
-                                   eos_token_id=torch.tensor([AA_TO_ID["<cls>"]]).to(device),
-                                   chunk_chunk_size=chunk_chunk_size,
-                                   device=device)
-
-        # Get the perplexity of the generated sequence
-        output_seq = output["generated"]
-        loss = torch.nn.functional.cross_entropy(torch.from_numpy(output["scores"]).permute(0, 2, 1),
-                                                 torch.from_numpy(output["generated_tokens"][0][None,:]))
-
-        # save only sequences with length < max_length
-        if len(output_seq[0]) < max_length:
-
-            gen_seqs[family_idx][params][output_seq[0]] = {"perplexity": torch.exp(loss).item()}
-
-    if save_path is not None:
-        if not os.path.exists("evaluation/generation/generated_sequences"):
-            os.mkdir("evaluation/generation/generated_sequences")
-        if not os.path.exists(save_path):
-            os.mkdir(save_path)
-        with open(f'{save_path}/{family_idx}_{params}_{n_samples_per_family}', "wb") as f:
-            pickle.dump(gen_seqs, f)
-        print(f"Sequences saved for family {family_idx} and parameters {params}")
-
-    return gen_seqs
-
-def generate_sequences(model_name,
-                       checkpoint,
-                       family_idxs=[],
-                       parameters_list=[],
-                       n_samples_per_family=100,
-                       chunk_size=1024,
-                       chunk_chunk_size=2**15,
-                       data_dir="data/",
-                       device="cuda:0"
-                       ):
-
-    # Load the test dataset
-    fim_strategy = "multiple_span"
-    mask_fraction = 0.2
-
-    dataset = ProteinMemmapDataset(
-        msa_memmap_path=f"{data_dir}open_protein_set_memmap.dat",
-        msa_memmap_meta_path=f"{data_dir}open_protein_set_memmap_indices.csv",
-        subset_path=f"{data_dir}cluster_testing_set.txt",
-        sample=False,
-        max_msa_len=-1,
-        reverse=False,
-        seed=0,
-        troubleshoot=False,
-        fim_strategy=fim_strategy,
-        always_mask=False,
-        max_position_embeddings=2048,
-        max_seq_position_embeddings=512,
-        add_position_ids="1d",
-        mask_fraction=mask_fraction
-    )
-
-    if model_name == "xlstm":
-        model_class = xLSTMLMHeadModel
-    elif model_name == "mamba":
-        model_class = MambaLMHeadModelwithPosids
-
-    save_path = f"evaluation/generation/generated_sequences/{checkpoint.split('/')[-1]}"
-
-    if model_name == "xlstm":
-        config_update_kwargs = {
-            "mlstm_backend": "chunkwise_variable",
-            "mlstm_chunksize": chunk_size,
-            "mlstm_return_last_state": True
-        }
-    else:
-        config_update_kwargs = {}
-
-
-    # load the model
-    model = load_model(checkpoint,
-                       model_class=model_class,
-                       device=device,
-                       dtype=torch.bfloat16,
-                       **config_update_kwargs,
-                       )
-    model = model.eval()
-    print("Model loaded.")
-
-    for family_idx in family_idxs:
-        for params in parameters_list:
-            params = tuple(params)
-            if not os.path.exists(f'{save_path}/{family_idx}_{params}_{n_samples_per_family}'):
-                gen_seqs = sample_sequences(
-                    dataset=dataset,
-                    model=model,
-                    family_idx=family_idx,
-                    params=params,
-                    n_samples_per_family=n_samples_per_family,
-                    chunk_chunk_size=chunk_chunk_size,
-                    save_path=save_path,
-                    device=device)
-
-                print(f"Sampled {len(gen_seqs[family_idx][params])} valid sequences.")
-            else:
-                print(f"Sequences for family {family_idx} and parameters {params} already exist.")
-
-
-if __name__ == "__main__":
-
-    parser = argparse.ArgumentParser(
-        description="Generate sequences."
-    )
-    parser.add_argument("--model_name", type=str, help="Either 'xlstm' or 'mamba'.")
-    parser.add_argument("--checkpoint", type=str, help="Path to model checkpoint.")
-    parser.add_argument("--family_idxs", type=str, help="List of family indices.")
-    parser.add_argument("--parameters_list", type=str, help="List of sampling parameters.")
-    parser.add_argument("--n_samples_per_family", type=int, default=100, help="Number of sequences to sample per family and parameter set.")
-    parser.add_argument("--chunk_size", type=int, default=1024, help="Chunk size for xLSTM context encoding.")
-    parser.add_argument("--chunk_chunk_size", type=int, default=2*15, help="Length of context sequence part processed at once.")
-    parser.add_argument("--data_dir", type=str, default="data/", help="Path to dataset.")
-    parser.add_argument("--device", type=str, default="cuda:0", help="Device.")
-
-    args = parser.parse_args()
-
-    family_idxs = json.loads(args.family_idxs)
-    parameters_list = json.loads(args.parameters_list)
-
-    # Run sequence generation
-    generate_sequences(
-        model_name=args.model_name,
-        checkpoint=args.checkpoint,
-        family_idxs=family_idxs,
-        parameters_list=parameters_list,
-        n_samples_per_family=args.n_samples_per_family,
-        chunk_size=args.chunk_size,
-        chunk_chunk_size=args.chunk_chunk_size,
-        data_dir=args.data_dir,
-        device=args.device,
-    )
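For reference, the script could also be driven directly from Python rather than the CLI; a direct-call sketch (the checkpoint path reuses the protxlstm_26M_30B name seen in protxlstm/index.html below and is a placeholder, as are the family indices):

    from protxlstm.applications.sample_sequences import generate_sequences

    generate_sequences(
        model_name="xlstm",
        checkpoint="checkpoints/protxlstm_26M_30B",  # placeholder checkpoint path
        family_idxs=[0, 1],
        parameters_list=[(100, 1.0, 10, 1.0)],       # (n_seqs_ctx, temperature, top_k, top_p)
        n_samples_per_family=10,
        data_dir="data/",
        device="cuda:0",
    )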
protxlstm/applications/score_sequences.py
DELETED
@@ -1,58 +0,0 @@
-import argparse
-import os
-import pickle
-
-from generation_utils.create_sequence_df import create_sequence_df
-from generation_utils.score_hamming import score_hamming
-from generation_utils.score_hmmer import score_hmmer
-from generation_utils.score_structure import score_structure
-
-
-def score_sequences(model_name,
-                    family_idx,
-                    num_sequences=100,
-                    data_dir="data/"):
-
-    if os.path.isfile(f"evaluation/generation/evaluations/{model_name}/sequence_df_{family_idx}"):
-        with open(f"evaluation/generation/evaluations/{model_name}/sequence_df_{family_idx}", "rb") as f:
-            sequence_df = pickle.load(f)
-    else:
-        sequence_df = create_sequence_df(model_name, family_idx, data_dir=data_dir, num_sequences=num_sequences)
-        if not os.path.exists("evaluation/generation/evaluations/"):
-            os.mkdir("evaluation/generation/evaluations/")
-        if not os.path.exists(f"evaluation/generation/evaluations/{model_name}/"):
-            os.mkdir(f"evaluation/generation/evaluations/{model_name}/")
-        with open(f"evaluation/generation/evaluations/{model_name}/sequence_df_{family_idx}", "wb") as f:
-            pickle.dump(sequence_df, f)
-
-    if not "min_hamming" in sequence_df.columns:
-        sequence_df = score_hamming(sequence_df, family_idx, data_dir)
-        with open(f"evaluation/generation/evaluations/{model_name}/sequence_df_{family_idx}", "wb") as f:
-            pickle.dump(sequence_df, f)
-
-    if not "score_gen" in sequence_df.columns:
-        sequence_df = score_hmmer(sequence_df, family_idx, data_dir)
-        with open(f"evaluation/generation/evaluations/{model_name}/sequence_df_{family_idx}", "wb") as f:
-            pickle.dump(sequence_df, f)
-
-    if not "ptm" in sequence_df.columns:
-        sequence_df = score_structure(sequence_df, family_idx)
-        with open(f"evaluation/generation/evaluations/{model_name}/sequence_df_{family_idx}", "wb") as f:
-            pickle.dump(sequence_df, f)
-
-    return sequence_df
-
-
-if __name__ == "__main__":
-
-    parser = argparse.ArgumentParser(
-        description="Generate sequences."
-    )
-    parser.add_argument("--model_name", type=str, help="Either 'xlstm' or 'mamba'.")
-    parser.add_argument("--family_idx", type=int, help="Family index.")
-    parser.add_argument("--num_sequences", type=int, default=100, help="Number of sequences.")
-    parser.add_argument("--data_dir", type=str, default="./data/", help="Path to dataset.")
-
-    args = parser.parse_args()
-
-    sequence_df = score_sequences(args.model_name, args.family_idx, args.num_sequences, args.data_dir)
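The wrapper pickles the DataFrame after each stage and checks for the stage's output columns on reload, so it can be re-run to resume a partially completed evaluation. A direct-call sketch (note the file imports generation_utils relative to its own directory, so it was presumably run with protxlstm/applications/ as the working directory):

    from score_sequences import score_sequences

    # Runs (or resumes) Hamming, HMMER, and structure scoring for one family
    df = score_sequences("xlstm", family_idx=0, num_sequences=100, data_dir="./data/")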
protxlstm/data.py
DELETED
@@ -1,60 +0,0 @@
-import csv
-import os
-
-import numpy as np
-from tqdm import tqdm
-
-from protxlstm.utils import load_sequences_from_msa_file, tokenizer
-
-def process_msa(msa_item):
-    msa_name, msa_path = msa_item
-    # Load an a3m file with all the context sequences
-    msa = load_sequences_from_msa_file(msa_path)
-    # Tokenize the sequences and concatenate them into a single array
-    tokens = tokenizer(msa, concatenate=True)
-    tokens = tokens.numpy()[0]
-    return msa_name, tokens
-
-def main(data_dir, output_dir):
-    msa_paths = {k: os.path.join(data_dir, k, 'a3m/uniclust30.a3m') for k in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, k))}
-    msa_items = list(msa_paths.items())
-
-    dataset_dictionary = {}
-    total_length = 0
-
-    # First pass: calculate total length of all concatenated arrays
-    for item in tqdm(msa_items):
-        try:
-            k, v = process_msa(item)
-            dataset_dictionary[k] = v
-            total_length += len(v)
-        except:
-            print(f"Error processing {item}")
-
-    # Initialize the memmap array with the calculated total length
-    memmap_path = os.path.join(output_dir, 'open_protein_set_memmap.dat')
-    concatenated_array = np.memmap(memmap_path, dtype='int8', mode='w+', shape=(total_length,))
-
-    with open(f'{output_dir}/open_protein_set_memmap_indices.csv', 'w', newline='') as csvfile:
-        csvwriter = csv.writer(csvfile)
-
-        csvwriter.writerow(['msa_id', 'Start', 'End'])
-
-        start_index = 0
-        for key, array in dataset_dictionary.items():
-            end_index = start_index + len(array) - 1
-            concatenated_array[start_index:end_index + 1] = array  # Write to memmap
-            csvwriter.writerow([key, start_index, end_index])
-            start_index = end_index + 1
-
-    # Ensure the data is written to disk
-    concatenated_array.flush()
-
-
-if __name__ == "__main__":
-    data_dir = 'data/a3m_files'
-    output_dir = 'data/'
-    main(data_dir, output_dir)
-
-
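The memmap plus index CSV written here is what ProteinMemmapDataset (below) consumes. A sketch of reading one cluster back by hand, assuming the files were built by this script (note the End column is stored inclusive):

    import numpy as np
    import pandas as pd

    tokens = np.memmap("data/open_protein_set_memmap.dat", dtype=np.int8, mode="r")
    meta = pd.read_csv("data/open_protein_set_memmap_indices.csv")

    row = meta.iloc[0]
    cluster_tokens = tokens[row.Start : row.End + 1]  # End is inclusive in the CSV
    print(row.msa_id, len(cluster_tokens))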
protxlstm/dataloaders.py
DELETED
@@ -1,249 +0,0 @@
-# Original code from ProtMamba under Apache License 2.0.
-#
-# Modifications made by Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
-#  - Uniclust30_Dataset renamed to ProteinMemmapDataset
-#  - Dataset input file format changed for more efficient dataloading
-#  - Option to use only a subset
-#  - DataCollatorForUniclust30Dataset renamed to ProteinDataCollator
-#  - Add sequence padding
-
-import numpy as np
-import pandas as pd
-import torch
-from torch.utils.data import DataLoader, Dataset
-from typing import Dict, Optional, Sequence
-
-from protxlstm.fim import MultipleSpanFIM, NoFIM, SingleSpanFIM
-from protxlstm.utils import AA_TO_ID
-
-
-# Make dataset
-class ProteinMemmapDataset(Dataset):
-    """
-    ProteinMemmapDataset is a PyTorch Dataset class for handling memory-mapped datasets of protein multiple sequence alignments (MSAs).
-
-    This class imports MSA data stored in memmap format and associated metadata CSVs. It supports flexible
-    data sampling strategies and inpainting methods for sequence manipulation and training purposes.
-
-    Args:
-        msa_memmap_path (str): Path to the memory-mapped file containing the MSA clusters.
-        msa_memmap_meta_path (str): Path to the CSV file with metadata linking MSA Cluster IDs and indices in the memmap array.
-        subset_path (str, optional): Path to a CSV file specifying a subset of cluster IDs to use.
-        sample (bool, optional): If True, randomly samples sequences from each cluster; otherwise, loads all sequences and shuffles them.
-        max_msa_len (int, optional): Maximum length of the MSA sequences to include. Defaults to -1 (no limit).
-        reverse (bool, optional): If True, reverses sequences with a probability of 0.5 and moves the last token to the front.
-        seed (int, optional): Random seed for reproducibility. Defaults to 42.
-        troubleshoot (bool, optional): If True, prints debugging information. Defaults to False.
-        fim_strategy (str, optional): Strategy for inpainting ("no-scramble", "one_span", or "multiple_span").
-        max_patches (int, optional): Number of patches for inpainting. Used when fim_strategy is "multiple_span".
-        mask_fraction (float, optional): Fraction of the patches to mask. Used when fim_strategy is "multiple_span".
-        always_mask (bool, optional): If True, ensures masking is applied in the inpainting process.
-        max_position_embeddings (int, optional): Maximum position embeddings. Defaults to 2048.
-        max_seq_position_embeddings (int, optional): Maximum sequence position embeddings for 2D positional IDs. Defaults to 512.
-        add_position_ids (str, optional): Type of position IDs to add ("none", "1d", or "2d"). Defaults to "1d".
-    """
-
-    _FIM = {"no-scramble": NoFIM, "one_span": SingleSpanFIM, "multiple_span": MultipleSpanFIM}
-    _POSIDS = {"none", "1d", "2d"}
-
-    def __init__(self,
-                 msa_memmap_path=None,
-                 msa_memmap_meta_path=None,
-                 subset_path=None,
-                 sample=False,
-                 max_msa_len=-1,
-                 reverse=False,
-                 seed=42,
-                 troubleshoot=False,
-                 fim_strategy="no-scramble",
-                 max_patches=5,
-                 mask_fraction=0.2,
-                 always_mask=False,
-                 max_position_embeddings=2048,
-                 max_seq_position_embeddings=512,
-                 add_position_ids="1d", ):
-
-        np.random.seed(seed)
-
-        if msa_memmap_path:
-            self.dataset = np.memmap(msa_memmap_path, dtype=np.int8, mode='r')
-            self.dataset_meta = pd.read_csv(msa_memmap_meta_path)
-            if subset_path:
-                subset_ids = pd.read_csv(subset_path, header=None, names=['ID'])['ID'].tolist()
-                self.dataset_meta = self.dataset_meta[self.dataset_meta['msa_id'].isin(subset_ids)]
-        else:
-            self.dataset = None
-
-        self.sample = sample
-        self.max_msa_len = max_msa_len
-        self.reverse = reverse
-        self.fim_strategy = fim_strategy
-        if fim_strategy in ProteinMemmapDataset._FIM:
-            self.fim = ProteinMemmapDataset._FIM[fim_strategy](max_patches=max_patches,
-                                                               mask_fraction=mask_fraction,
-                                                               always_mask=always_mask,
-                                                               add_position_ids=add_position_ids != "none",
-                                                               troubleshoot=troubleshoot)
-        else:
-            raise ValueError(f'Fill in the middle stragy "{fim_strategy}" not recognized.')
-
-        self.max_position_embeddings = max_position_embeddings
-        self.max_seq_position_embeddings = max_seq_position_embeddings
-        self.add_position_ids = add_position_ids
-
-        self.troubleshoot = troubleshoot
-
-    def __len__(self):
-        # meta dataframe has one row for each MSA cluster
-        return len(self.dataset_meta)
-
-    def __getitem__(self, idx):
-        # get all the sequences in the cluster
-        sequences = self.get_sequences(idx)
-        # get total number of sequences in the cluster and choose how many to sample
-        orig_num_sequences = len(self.get_index_start_of_sequences(sequences))
-        num_sequences = np.random.randint(1, orig_num_sequences + 1) if self.sample else orig_num_sequences
-        # sample the sequences
-        sequences, position_ids = self.sample_sequences(sequences, num_sequences)
-        # with probability 0.5, reverse the sequences and move the last token to the front
-        sequences, position_ids = self.reverse_sequences(sequences, position_ids) if (
-                self.reverse and np.random.rand() > 0.5) else sequences, position_ids
-        # limit the length of the MSA
-        sequences = sequences[:self.max_msa_len] if self.max_msa_len > 0 else sequences
-        if self.add_position_ids != "none":
-            position_ids = position_ids[:self.max_msa_len] if self.max_msa_len > 0 else position_ids
-        # convert to tensor
-        sequences = torch.asarray(sequences, dtype=torch.int64)
-        position_ids = torch.asarray(position_ids, dtype=torch.int64).clamp(0,
-                self.max_position_embeddings - 1) if self.add_position_ids != "none" else None
-
-        if self.troubleshoot:
-            print(
-                f"Cluster {idx} has {orig_num_sequences} sequences, of which {num_sequences} sampled now. Total MSA length: {len(sequences)}")
-        if self.add_position_ids == "1d":
-            return dict(input_ids=sequences, position_ids=position_ids, labels=sequences)
-        if self.add_position_ids == "2d":
-            seq_position_ids = (sequences == AA_TO_ID["<cls>"]).int().cumsum(-1).clamp(0,
-                    self.max_seq_position_embeddings - 1).contiguous()
-            return dict(input_ids=sequences, position_ids=position_ids, seq_position_ids=seq_position_ids,
-                        labels=sequences)
-        return dict(input_ids=sequences, labels=sequences)
-
-    def get_msa_id(self, idx):
-        """Get the MSA ID in the cluster with index `idx`."""
-        cluster_meta = self.dataset_meta.iloc[idx]
-        return cluster_meta.msa_id
-
-    def get_idx_from_msa_id(self, msa_id):
-        """Get `idx` with the MSA ID"""
-        return self.dataset_meta[self.dataset_meta.msa_id == msa_id].index[0]
-
-    def get_sequences(self, idx):
-        """Get the sequences in the cluster with index `idx`."""
-        cluster_meta = self.dataset_meta.iloc[idx]
-        sequences = self.dataset[cluster_meta.Start : cluster_meta.End]
-        return sequences
-
-    def get_index_start_of_sequences(self, sequences):
-        """Get the positions of the start of each sequence in the cluster."""
-        return np.where(sequences == 0)[0]
-
-    def reverse_sequences(self, sequence, position_ids=None):
-        """Reverse the sequences and move the last token to the front."""
-        sequence = sequence[::-1]
-        if position_ids is not None:
-            position_ids = position_ids[::-1]
-        return np.concatenate([sequence[-1:], sequence[:-1]]), np.concatenate(
-            [position_ids[-1:], position_ids[:-1]]) if position_ids is not None else None
-
-    def sample_sequences(self, sequences, num_sequences, shuffle=True):
-        """Sample `num_sequences` from the sequences in the cluster."""
-        L = len(sequences)
-        # get the indexes of the start of each sequence
-        inds = self.get_index_start_of_sequences(sequences)
-        # check that there are sequences in the cluster and that there are enough of them
-        assert len(inds) > 0, "No sequences found in cluster."
-        assert len(inds) >= num_sequences, "Not enough sequences in cluster."
-        # sample n_sequences randomly from the sequences
-        if shuffle:
-            which_seqs = np.random.choice(np.arange(len(inds)), num_sequences, replace=False)
-        else:
-            which_seqs = np.arange(len(inds))[-num_sequences:]
-        # get the tuples of start and end indexes of the sequences
-        tuples = [(inds[i], inds[i + 1]) if i < len(inds) - 1 else (inds[i], L) for i in which_seqs]
-        if self.troubleshoot:
-            print(f"Sampled sequences: {tuples}")
-        # concatenate the sequences
-        sequences, position_ids = self.fim.apply(sequences, tuples)
-        return sequences, position_ids
-
-
-
-def make_dataloader(dataset):
-    """Basic function to make a dataloader.
-    """
-    dataloader = DataLoader(dataset)
-    return dataloader
-
-
-class ProteinDataCollator(object):
-    """
-    Collate examples into a batch, and pad batch to a specified maximum sequence length,
-    or to the longest sequence in the batch if max_sequence_length is None.
-    """
-    def __init__(self, max_sequence_length: Optional[int] = None):
-        """
-        Initialize the collator with an optional max_sequence_length.
-
-        Args:
-            max_sequence_length (Optional[int]): The maximum sequence length to pad/truncate to.
-                If None, pad to the longest sequence in the batch.
-        """
-        self.max_sequence_length = max_sequence_length
-
-    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
-
-        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "input_ids"))
-
-        longest_seq = max(len(seq) for seq in input_ids)
-        if self.max_sequence_length is None:
-            max_len = longest_seq
-        else:
-            max_len = self.max_sequence_length
-
-        input_ids = self.pad_sequences(input_ids, max_len, padding_value=AA_TO_ID["<pad>"])
-
-        labels = self.pad_sequences(labels, longest_seq, padding_value=AA_TO_ID["<pad>"])
-        labels = self.pad_sequences(labels, max_len, padding_value=-100)
-
-        return_dict = dict(
-            input_ids=input_ids,
-            labels=labels,
-            attention_mask=input_ids.ne(AA_TO_ID["<pad>"])
-        )
-
-        if "position_ids" in instances[0]:
-
-            position_ids = [instance["position_ids"] for instance in instances]
-            position_ids = self.pad_sequences(position_ids, max_len, padding_value=0)
-            return_dict["position_ids"] = position_ids
-
-        if "seq_position_ids" in instances[0]:
-            seq_position_ids = [instance["seq_position_ids"] for instance in instances]
-            seq_position_ids = self.pad_sequences(seq_position_ids, max_len, padding_value=0)
-            return_dict["seq_position_ids"] = seq_position_ids
-
-        return return_dict
-
-    def pad_sequences(self, seqs, max_length, padding_value):
-        # truncate long sequences (redundant, already done in __getitem__, maybe safe to remove)
-        seqs = [seq[:max_length] for seq in seqs]
-
-        # pad to same length
-        seqs = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=padding_value)
-
-        # pad to max length
-        padding = max_length - seqs.size(1)
-        seqs = torch.nn.functional.pad(seqs, (0, padding), value=padding_value)
-
-        return seqs
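A minimal sketch of wiring the deleted dataset and collator together (paths assume the memmap built by protxlstm/data.py above):

    from torch.utils.data import DataLoader
    from protxlstm.dataloaders import ProteinMemmapDataset, ProteinDataCollator

    dataset = ProteinMemmapDataset(
        msa_memmap_path="data/open_protein_set_memmap.dat",
        msa_memmap_meta_path="data/open_protein_set_memmap_indices.csv",
        max_msa_len=2048,
        fim_strategy="multiple_span",
        add_position_ids="1d",
    )
    loader = DataLoader(dataset, batch_size=2,
                        collate_fn=ProteinDataCollator(max_sequence_length=2048))
    batch = next(iter(loader))
    print(batch["input_ids"].shape, batch["position_ids"].shape)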
protxlstm/fim.py
DELETED
@@ -1,203 +0,0 @@
|
|
1 |
-
|
2 |
-
# Original code from ProtMamba under Apache License 2.0.
|
3 |
-
|
4 |
-
from protxlstm.utils import MASK_TO_ID, AA_TO_ID
|
5 |
-
import numpy as np
|
6 |
-
|
7 |
-
class AbstractFIM(object):
|
8 |
-
def __init__(self,
|
9 |
-
max_patches=5,
|
10 |
-
mask_fraction=0.2,
|
11 |
-
always_mask=False,
|
12 |
-
mask_tokens=MASK_TO_ID,
|
13 |
-
eos_token=AA_TO_ID["<eos>"],
|
14 |
-
add_position_ids=False,
|
15 |
-
troubleshoot=False):
|
16 |
-
"""
|
17 |
-
This class is designed to concatenate sequences based on different scrambling strategies.
|
18 |
-
It takes a list of sequences, tuples indicating the start and end indices of each sequence,
|
19 |
-
an optional number of patches to sample, and a scrambling strategy as inputs.
|
20 |
-
"""
|
21 |
-
self.troubleshoot = troubleshoot
|
22 |
-
self.max_patches = max_patches
|
23 |
-
self.mask_fraction = mask_fraction
|
24 |
-
self.mask_tokens = mask_tokens
|
25 |
-
assert len(
|
26 |
-
self.mask_tokens) >= self.max_patches, "Number of mask tokens must be bigger than max number of patches."
|
27 |
-
self.eos_token = eos_token
|
28 |
-
self.add_position_ids = add_position_ids
|
29 |
-
self.always_mask = always_mask
|
30 |
-
|
31 |
-
def apply(self, sequences, tuples):
|
32 |
-
"""
|
33 |
-
This function concatenates the sequences scrambling each one according to the scrambling strategy.
|
34 |
-
"""
|
35 |
-
input_ids, position_ids = [], []
|
36 |
-
for t in tuples:
|
37 |
-
seq, pos = self.fim(sequences, t)
|
38 |
-
input_ids.extend(seq)
|
39 |
-
if self.add_position_ids:
|
40 |
-
position_ids.extend(pos)
|
41 |
-
if self.add_position_ids:
|
42 |
-
return input_ids, position_ids
|
43 |
-
return input_ids, None
|
44 |
-
|
45 |
-
def fim(self, sequences, t):
|
46 |
-
"""
|
47 |
-
This function concatenates the sequence's parts based on the scrambling strategy.
|
48 |
-
"""
|
49 |
-
raise NotImplementedError
|
50 |
-
|
51 |
-
|
52 |
-
class NoFIM(AbstractFIM):
|
53 |
-
def __init__(self,
|
54 |
-
max_patches=5,
|
55 |
-
mask_fraction=0.2,
|
56 |
-
always_mask=False,
|
57 |
-
mask_tokens=MASK_TO_ID,
|
58 |
-
eos_token=AA_TO_ID["<eos>"],
|
59 |
-
add_position_ids=False,
|
60 |
-
troubleshoot=False):
|
61 |
-
super().__init__(max_patches, mask_fraction, always_mask, mask_tokens, eos_token, add_position_ids, troubleshoot)
|
62 |
-
|
63 |
-
def fim(self, sequences, t):
|
64 |
-
"""
|
65 |
-
This function keeps the sequence identical without any scrambling.
|
66 |
-
"""
|
67 |
-
if self.add_position_ids:
|
68 |
-
position_ids = np.arange(t[0], t[1]) - t[0]
|
69 |
-
return sequences[t[0]:t[1]], position_ids
|
70 |
-
return sequences[t[0]:t[1]], None
|
71 |
-
|
72 |
-
|
73 |
-
class SingleSpanFIM(AbstractFIM):
|
74 |
-
|
75 |
-
def __init__(self,
|
76 |
-
max_patches=5,
|
77 |
-
mask_fraction=0.2,
|
78 |
-
always_mask=False,
|
79 |
-
mask_tokens=MASK_TO_ID,
|
80 |
-
eos_token=AA_TO_ID["<eos>"],
|
81 |
-
add_position_ids=False,
|
82 |
-
troubleshoot=False):
|
83 |
-
super().__init__(max_patches, mask_fraction, always_mask, mask_tokens, eos_token, add_position_ids, troubleshoot)
|
84 |
-
|
85 |
-
def fim(self, sequences, t):
|
86 |
-
"""
|
87 |
-
This function creates and concatenates parts of the sequences based on the OpenAI scrambling strategy.
|
88 |
-
It randomly selects two indices within the range of the given tuple,
|
89 |
-
splits the sequence into three parts based on these indices, and then concatenates them with the
|
90 |
-
masked patch at the end
|
91 |
-
"""
|
92 |
-
new_tuple = tuple(np.sort(np.random.choice(np.arange(t[0] + 1, t[1]), 2, replace=False)))
|
93 |
-
part1 = sequences[t[0]:new_tuple[0]]
|
94 |
-
part2 = sequences[new_tuple[0]:new_tuple[1]]
|
95 |
-
part3 = sequences[new_tuple[1]:t[1]]
|
96 |
-
sequence = np.concatenate([part1, [self.mask_tokens["<mask-1>"]], part3, [self.mask_tokens["<mask-1>"]], part2])
|
97 |
-
position_ids_sequence = None
|
98 |
-
if self.add_position_ids:
|
99 |
-
position_ids = np.arange(t[0], t[1]) - t[0]
|
100 |
-
position_ids_part1 = position_ids[t[0]:new_tuple[0]]
|
101 |
-
position_ids_part2 = position_ids[new_tuple[0]:new_tuple[1]]
|
102 |
-
position_ids_part3 = position_ids[new_tuple[1]:t[1]]
|
103 |
-
position_ids_sequence = np.concatenate(
|
104 |
-
[position_ids_part1, [position_ids_part2[0]], position_ids_part3, [position_ids_part2[0]],
|
105 |
-
position_ids_part2])
|
106 |
-
|
107 |
-
return sequence, position_ids_sequence
|
108 |
-
|
109 |
-
|
110 |
-
class MultipleSpanFIM(AbstractFIM):
|
111 |
-
def __init__(self,
|
112 |
-
max_patches=5,
|
                 mask_fraction=0.2,
                 always_mask=False,
                 mask_tokens=MASK_TO_ID,
                 eos_token=AA_TO_ID["<eos>"],
                 add_position_ids=False,
                 troubleshoot=False):
        super().__init__(max_patches, mask_fraction, always_mask, mask_tokens, eos_token, add_position_ids, troubleshoot)

    def fim(self, sequences, t):
        """
        This function creates and concatenates parts of the sequences based on the inpaint scrambling strategy.
        It randomly selects `2*num_patches` indices within the range of the given tuple,
        splits the sequence into unmasked and masked parts based on these indices, and then concatenates them.
        The number of patches is sampled from a Poisson distribution with upper limit `self.max_patches` and average 1.
        The concatenation is done by joining all unmasked parts (interleaved with mask tokens) and afterwards
        all masked parts (interleaved with mask tokens). At the end of the unmasked parts, a special token is added
        to indicate the end of the unmasked parts, and at the end of the masked parts, a special token is added
        to indicate the end of the masked parts.
        """
        def sample_lengths(start, end):
            """
            Sample a length uniformly from 1 to max_L*self.mask_fraction (must be bigger than 1).
            If the length is larger than max_L, return max_L.
            """
            max_L = end - start
            length = np.random.randint(1, max(int(max_L * self.mask_fraction), 2))
            return min(length, max_L)

        # sample num_patches from a discrete Poisson distribution with upper limit max_patches
        num_patches = 1000
        while num_patches > self.max_patches:
            num_patches = np.random.poisson(1)
        if self.always_mask:
            num_patches = max(num_patches, 1)
        # sample num_patches starting points for the masked positions (+ final position)
        start_patches = list(np.sort(np.random.choice(np.arange(t[0] + 1, t[1]),
                                                      num_patches,
                                                      replace=False))) + [t[1]]
        # sample num_patches lengths of the patches
        len_patches = [sample_lengths(start_patches[i], start_patches[i + 1])
                       for i in range(len(start_patches) - 1)]
        # create masked tuples with start and end indices of the patches
        masked_tuples = [(start_patches[i], start_patches[i] + len_patches[i]) for i in range(len(start_patches) - 1)]
        # split the sequences into unmasked and masked parts
        unmasked_sequence, masked_sequence, unmasked_position_ids, masked_position_ids = self.split_sequences(sequences,
                                                                                                              t,
                                                                                                              masked_tuples)

        if self.troubleshoot:
            print(f"For sequence in {t}: sampled {num_patches=}, {start_patches=}, {len_patches=}, {masked_tuples=}")
        # concatenate the unmasked and masked parts
        return unmasked_sequence + masked_sequence, unmasked_position_ids + masked_position_ids if self.add_position_ids else None

    def split_sequences(self, sequences, t, masked_tuples):
        """
        This function splits the sequences into unmasked and masked parts based on the given tuples.
        Args:
            t (tuple): The start and end index of each sequence.
            masked_tuples (list): A list of tuples specifying the indices for masked regions.
        Returns:
            unmasked_parts (list): The unmasked parts of the sequences interleaved with mask_tokens.
            masked_parts (list): The masked parts of the sequences interleaved with mask_tokens.
        """
        unmasked_parts, masked_parts = [], []
        unmasked_positions, masked_positions = [], []
        position_ids = None
        start, end = t
        if self.add_position_ids:
            position_ids = np.arange(start, end) - start
        for i, region in enumerate(masked_tuples):
            mask_token = self.mask_tokens[f"<mask-{i + 1}>"]
            unmasked_parts.extend(sequences[start:region[0]])
            unmasked_parts.append(mask_token)
            masked_parts.append(mask_token)
            masked_parts.extend(sequences[region[0]:region[1]])
            if self.add_position_ids:
                unmasked_positions.extend(position_ids[start-t[0]:region[0]-t[0]])
                unmasked_positions.append(position_ids[region[0]-t[0]])
                masked_positions.append(position_ids[region[0]-t[0]])
                masked_positions.extend(position_ids[region[0]-t[0]:region[1]-t[0]])

            start = region[1]
        unmasked_parts.extend(sequences[start:end])
        if self.add_position_ids:
            unmasked_positions.extend(position_ids[start-t[0]:end-t[0]])
        if len(masked_tuples) > 0:
            unmasked_parts.append(self.eos_token)
            if self.add_position_ids:
                unmasked_positions.append(0)
        return unmasked_parts, masked_parts, unmasked_positions, masked_positions
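The inpainting rearrangement above is easiest to see on a concrete toy input. The following standalone sketch is not part of the repository: the special-token ids and the patch positions are invented for illustration, whereas `fim` above samples them via a Poisson/uniform scheme and takes the token ids from MASK_TO_ID and AA_TO_ID.

# Toy rendition of the FIM/inpainting rearrangement performed by split_sequences.
MASK_1, MASK_2, EOS = 101, 102, 100  # made-up special-token ids

def inpaint_concat(seq, masked_tuples):
    """Emit unmasked parts interleaved with mask tokens, then <eos>,
    then each masked part behind its mask token."""
    unmasked, masked = [], []
    start = 0
    for i, (a, b) in enumerate(masked_tuples):
        mask_token = (MASK_1, MASK_2)[i]
        unmasked.extend(seq[start:a])      # visible context up to the patch
        unmasked.append(mask_token)        # placeholder where the patch was cut
        masked.append(mask_token)          # patch content, announced by its token
        masked.extend(seq[a:b])
        start = b
    unmasked.extend(seq[start:])
    unmasked.append(EOS)                   # end of the visible part
    return unmasked + masked

print(inpaint_concat(list(range(10)), [(2, 4), (7, 8)]))
# [0, 1, 101, 4, 5, 6, 102, 8, 9, 100, 101, 2, 3, 102, 7]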
protxlstm/index.html
DELETED
@@ -1,16 +0,0 @@
<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 3.2 Final//EN">
<html>
 <head>
  <title>Index of /research/Bio-xLSTM/downloads/Prot-xLSTM/checkpoints/protxlstm_26M_30B</title>
 </head>
 <body>
<h1>Index of /research/Bio-xLSTM/downloads/Prot-xLSTM/checkpoints/protxlstm_26M_30B</h1>
<pre><img src="/icons/blank.gif" alt="Icon "> <a href="?C=N;O=D">Name</a> <a href="?C=M;O=A">Last modified</a> <a href="?C=S;O=A">Size</a> <a href="?C=D;O=A">Description</a><hr><img src="/icons/back.gif" alt="[PARENTDIR]"> <a href="/research/Bio-xLSTM/downloads/Prot-xLSTM/checkpoints/">Parent Directory</a> -
<img src="/icons/unknown.gif" alt="[ ]"> <a href="config.json">config.json</a> 2024-11-04 14:36 1.8K
<img src="/icons/unknown.gif" alt="[ ]"> <a href="optimizer.pt">optimizer.pt</a> 2024-11-04 14:36 198M
<img src="/icons/binary.gif" alt="[ ]"> <a href="pytorch_model.bin">pytorch_model.bin</a> 2024-11-04 14:36 99M
<img src="/icons/unknown.gif" alt="[ ]"> <a href="rng_state.pth">rng_state.pth</a> 2024-11-04 14:36 14K
<img src="/icons/unknown.gif" alt="[ ]"> <a href="scheduler.pt">scheduler.pt</a> 2024-11-04 14:36 1.0K
<img src="/icons/unknown.gif" alt="[ ]"> <a href="trainer_state.json">trainer_state.json</a> 2024-11-04 14:36 2.4M
<hr></pre>
</body></html>
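The deleted listing above documents which files make up the protxlstm_26M_30B checkpoint (configuration, weights, optimizer and trainer state). For orientation only, here is a minimal sketch of inspecting such a directory with plain PyTorch, assuming it has been downloaded to a hypothetical local path; the repository itself wraps loading in protxlstm.utils.load_model (used in protxlstm/train.py below).

import json
import torch

ckpt_dir = "checkpoints/protxlstm_26M_30B"  # hypothetical local download path
with open(f"{ckpt_dir}/config.json") as f:
    config = json.load(f)                   # model hyperparameters
state_dict = torch.load(f"{ckpt_dir}/pytorch_model.bin", map_location="cpu")
print(sum(p.numel() for p in state_dict.values()))  # parameter count, ~26M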
protxlstm/models/llama.py
DELETED
@@ -1,342 +0,0 @@
import json
import math
import os
from collections import namedtuple
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

from protxlstm.xlstm.components.rotary_position import compute_freqs_cis

# Note: generation capabilities are not implemented for the transformer

class TransformerConfig(PretrainedConfig):

    model_type = "llama"

    def __init__(
        self,
        d_model,
        n_layer,
        n_heads,
        n_kv_heads,
        bidirectional,
        vocab_size,
        hidden_dim,
        multiple_of,  # MLP hidden layer size will be multiple of
        norm_eps,
        max_length,
        dropout,
        max_position_embeddings,
        rope_base_frequency,
        **kwargs
    ):
        super().__init__(**kwargs)

        # default hyperparameters for the Llama 7B model
        self.dim = d_model
        self.n_layers = n_layer
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.causal_attention = not bidirectional
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.multiple_of = multiple_of
        self.norm_eps = norm_eps
        self.max_seq_len = max_length
        self.dropout = dropout
        self.max_position_embeddings = max_position_embeddings
        self.rope_base_frequency = rope_base_frequency

class RMSNorm_transformer(torch.nn.Module):
    def __init__(self, dim: int, eps: float):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cos = torch.cos(freqs)  # real part
    freqs_sin = torch.sin(freqs)  # imaginary part
    return freqs_cos, freqs_sin

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(shape)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:

    # reshape xq and xk to match the complex representation
    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

    # reshape freqs_cos and freqs_sin for broadcasting
    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)

    # apply rotation using real numbers
    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

    # flatten last two dimensions
    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)

    return xq_out.type_as(xq), xk_out.type_as(xk)

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )

class Attention(nn.Module):
    def __init__(self, args: TransformerConfig):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        model_parallel_size = 1
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        self.causal_attention = args.causal_attention

        # use flash attention or a manual implementation?
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash and self.causal_attention:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
            mask = torch.triu(mask, diagonal=1)
            self.register_buffer("mask", mask)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
    ):
        bsz, seqlen, _ = x.shape

        # QKV
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # RoPE relative positional embeddings
        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)

        # grouped multiquery attention: expand out keys and values
        xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
        xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

        # make heads into a batch dimension
        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        # flash implementation
        if self.flash:
            output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=self.causal_attention)
        else:
            # manual implementation
            scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
            if self.causal_attention:
                scores = scores + self.mask[:, :, :seqlen, :seqlen]  # (bs, n_local_heads, seqlen, cache_len + seqlen)
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = torch.matmul(scores, xv)  # (bs, n_local_heads, seqlen, head_dim)

        # restore time as batch dimension and concat heads
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

        # final projection into the residual stream
        output = self.wo(output)
        output = self.resid_dropout(output)
        return output

class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = 4 * dim
            hidden_dim = int(2 * hidden_dim / 3)
            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))

class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: TransformerConfig):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=args.hidden_dim,
            multiple_of=args.multiple_of,
            dropout=args.dropout,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm_transformer(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm_transformer(args.dim, eps=args.norm_eps)

    def forward(self, x, freqs_cos, freqs_sin):
        h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out

class Transformer(nn.Module):

    last_loss: Optional[torch.Tensor]

    def __init__(self, params: TransformerConfig):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.dropout = nn.Dropout(params.dropout)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))
        self.layer_head_dim = self.layers[0].head_dim

        self.norm = RMSNorm_transformer(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)

        # share the unembedding parameters with the embedding parameters
        self.tok_embeddings.weight = self.output.weight  # https://paperswithcode.com/method/weight-tying

        # some useful precompute for the RoPE relative positional embeddings
        # freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
        # self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        # self.register_buffer("freqs_sin", freqs_sin, persistent=False)

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * params.n_layers))

        # Initialize attribute for the loss of the last forward call. This will be set if the forward is called with a targets tensor.
        self.last_loss = None

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        h = self.dropout(h)
        # freqs_cos = self.freqs_cos[:seqlen]
        # freqs_sin = self.freqs_sin[:seqlen]

        if 'position_ids' in kwargs:
            freqs_cos, freqs_sin = compute_freqs_cis(kwargs.pop("position_ids"), self.layer_head_dim, theta=self.params.rope_base_frequency)
        else:
            raise ValueError('Llama model only implemented with RoPEs')

        freqs_cos = freqs_cos.squeeze()
        freqs_sin = freqs_sin.squeeze()

        for layer in self.layers:
            h = layer(h, freqs_cos, freqs_sin)
        h = self.norm(h)

        if targets is not None:
            logits = self.output(h)
            self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            logits = self.output(h)
            self.last_loss = None

        return logits

class TransformerLMHeadModel(nn.Module):

    def __init__(
        self,
        config: TransformerConfig,
    ) -> None:

        super().__init__()

        self.config = config

        self.backbone = Transformer(config)

    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
        """
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """

        lm_logits = self.backbone(input_ids, position_ids=position_ids)

        CausalLMOutput = namedtuple("CausalLMOutput", ["loss", "logits"])
        return CausalLMOutput(loss=None, logits=lm_logits)

    def save_pretrained(self, save_directory):
        """
        Save the model and its configuration file to a directory.
        """

        # Ensure save_directory exists
        os.makedirs(save_directory, exist_ok=True)

        # Save the model's state_dict
        model_path = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.state_dict(), model_path)

        # Save the configuration of the model
        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "w") as f:
            json.dump(self.config.to_dict(), f)
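For reference, the deleted Transformer stack can be smoke-tested end to end at the parent revision of this commit. The sketch below uses invented toy hyperparameters (not a configuration from the paper) and assumes that protxlstm's `compute_freqs_cis` returns per-position cos/sin tables compatible with `apply_rotary_emb`, i.e. shape (seq_len, head_dim // 2) after the squeeze.

import torch
from protxlstm.models.llama import TransformerConfig, TransformerLMHeadModel

# Toy sizes, chosen only so the shapes work out.
config = TransformerConfig(
    d_model=64, n_layer=2, n_heads=4, n_kv_heads=2, bidirectional=False,
    vocab_size=40, hidden_dim=128, multiple_of=8, norm_eps=1e-5,
    max_length=128, dropout=0.0, max_position_embeddings=128,
    rope_base_frequency=10000,
)
model = TransformerLMHeadModel(config)
input_ids = torch.randint(0, 40, (1, 16))
position_ids = torch.arange(16).unsqueeze(0)  # forward raises without RoPE positions
out = model(input_ids, position_ids=position_ids)
print(out.logits.shape)  # expected: torch.Size([1, 16, 40])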
protxlstm/models/mamba.py
DELETED
@@ -1,833 +0,0 @@
# Original code from ProtMamba under Apache License 2.0.

import json
import os
from collections import namedtuple
from dataclasses import dataclass, field
from functools import partial

from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.modules.mamba_simple import Block, Mamba
from mamba_ssm.models.mixer_seq_simple import MixerModel, _init_weights
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from transformers import PretrainedConfig

from protxlstm.generation import GenerationMixinSafe

@dataclass
class MambaConfig(PretrainedConfig):
    d_model: int = 2560
    n_layer: int = 64
    vocab_size: int = 50277
    ssm_cfg: dict = field(default_factory=dict)
    rms_norm: bool = True
    residual_in_fp32: bool = True
    fused_add_norm: bool = True
    pad_vocab_size_multiple: int = 8
    max_position_embeddings: int = 2048

def create_block(
    d_model,
    ssm_cfg=None,
    norm_epsilon=1e-5,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
    checkpoint_mixer=False,
):
    if ssm_cfg is None:
        ssm_cfg = {}
    factory_kwargs = {"device": device, "dtype": dtype}
    mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
    norm_cls = partial(
        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
    )
    block = Block(
        d_model,
        mixer_cls,
        norm_cls=norm_cls,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    if checkpoint_mixer:
        block.mixer = CheckpointedModule(block.mixer)
    return block

class CheckpointedModule(torch.nn.Module):
    def __init__(self, layer):
        super().__init__()
        self.ckpt_layer = layer

    def forward(self, x, *args, **kwargs):
        return checkpoint(self.ckpt_layer, x, use_reentrant=False)

    # def state_dict(self, **kwargs):
    #     # Get the state dict of the underlying layer
    #     layer_state_dict = self.ckpt_layer.state_dict(**kwargs)
    #     # Create a new state dict with the original keys
    #     state_dict = {k.replace('ckpt_layer.', ''): v for k, v in layer_state_dict.items()}
    #     return state_dict

class MixerModelSafe(MixerModel):
    """
    Overwrite the forward method to allow saving intermediate layers.
    """

    def forward(self, input_ids, inference_params=None, save_layer=[]):
        hidden_states = self.embedding(input_ids)
        residual = None
        if len(save_layer) > 0:
            hidden_states_dict = {}
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states, residual, inference_params=inference_params
            )
            if i + 1 in save_layer:
                hidden_states_dict[i + 1] = (
                    hidden_states.detach().cpu().to(torch.float).numpy()
                )
        if len(save_layer) > 0:
            return hidden_states_dict

        if not self.fused_add_norm:
            residual = (
                (hidden_states + residual) if residual is not None else hidden_states
            )
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # Set prenorm=False here since we don't need the residual
            fused_add_norm_fn = (
                rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
            )
            hidden_states = fused_add_norm_fn(
                hidden_states,
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
            )
        return hidden_states

class MixerModelWithPosids(nn.Module):
    r"""Mixer model for Mamba but we add positional encodings to the input embeddings."""

    def __init__(
        self,
        d_model: int,
        n_layer: int,
        vocab_size: int,
        max_position_embeddings: int,
        ssm_cfg=None,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = False,
        initializer_cfg=None,
        fused_add_norm=False,
        residual_in_fp32=False,
        device=None,
        dtype=None,
        checkpoint_mixer=False,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32

        self.embedding = nn.Embedding(vocab_size, d_model // 2, **factory_kwargs)
        self.position_embedding = nn.Embedding(
            max_position_embeddings, d_model - d_model // 2, **factory_kwargs
        )

        # We change the order of residual and layer norm:
        # Instead of LN -> Attn / MLP -> Add, we do:
        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
        # the main branch (output of MLP / Mixer). The model definition is unchanged.
        # This is for performance reason: we can fuse add + layer_norm.
        self.fused_add_norm = fused_add_norm
        if self.fused_add_norm:
            if layer_norm_fn is None or rms_norm_fn is None:
                raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")

        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model,
                    ssm_cfg=ssm_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    checkpoint_mixer=checkpoint_mixer,
                    **factory_kwargs,
                )
                for i in range(n_layer)
            ]
        )

        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
            d_model, eps=norm_epsilon, **factory_kwargs
        )

        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return {
            i: layer.allocate_inference_cache(
                batch_size, max_seqlen, dtype=dtype, **kwargs
            )
            for i, layer in enumerate(self.layers)
        }

    def forward(self, input_ids, position_ids, inference_params=None, save_layer=[]):
        hidden_states = torch.cat(
            [
                self.embedding(input_ids),
                self.position_embedding(position_ids),
            ],
            -1,
        )
        residual = None
        if len(save_layer) > 0:
            hidden_states_dict = {}
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states, residual, inference_params=inference_params
            )
            if i + 1 in save_layer:
                hidden_states_dict[i + 1] = (
                    hidden_states.detach().cpu().to(torch.float).numpy()
                )
        if len(save_layer) > 0:
            return hidden_states_dict

        if not self.fused_add_norm:
            residual = (
                (hidden_states + residual) if residual is not None else hidden_states
            )
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            fused_add_norm_fn = (
                rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
            )
            hidden_states = fused_add_norm_fn(
                hidden_states,
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
            )
        return hidden_states

class MixerModelWith2DPosids(nn.Module):
    r"""Mixer model for Mamba but we add positional encodings to the input embeddings."""

    def __init__(
        self,
        d_model: int,
        n_layer: int,
        vocab_size: int,
        max_position_embeddings: int,
        max_sequence_position_embeddings: int = 512,
        ssm_cfg=None,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = False,
        initializer_cfg=None,
        fused_add_norm=False,
        residual_in_fp32=False,
        device=None,
        dtype=None,
        checkpoint_mixer=False,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32

        self.embedding = nn.Embedding(
            vocab_size, d_model - 2 * d_model // 4, **factory_kwargs
        )
        self.position_embedding = nn.Embedding(
            max_position_embeddings, d_model // 4, **factory_kwargs
        )
        self.seq_position_embedding = nn.Embedding(
            max_sequence_position_embeddings, d_model // 4, **factory_kwargs
        )
        self.d_embeddings = d_model - 2 * d_model // 4

        # We change the order of residual and layer norm:
        # Instead of LN -> Attn / MLP -> Add, we do:
        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
        # the main branch (output of MLP / Mixer). The model definition is unchanged.
        # This is for performance reason: we can fuse add + layer_norm.
        self.fused_add_norm = fused_add_norm
        if self.fused_add_norm:
            if layer_norm_fn is None or rms_norm_fn is None:
                raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")

        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model,
                    ssm_cfg=ssm_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    checkpoint_mixer=checkpoint_mixer,
                    **factory_kwargs,
                )
                for i in range(n_layer)
            ]
        )

        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
            d_model, eps=norm_epsilon, **factory_kwargs
        )

        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return {
            i: layer.allocate_inference_cache(
                batch_size, max_seqlen, dtype=dtype, **kwargs
            )
            for i, layer in enumerate(self.layers)
        }

    def forward(
        self,
        input_ids,
        position_ids,
        seq_position_ids,
        inference_params=None,
        save_layer=[],
    ):
        hidden_states = torch.cat(
            [
                self.embedding(input_ids),
                self.position_embedding(position_ids),
                self.seq_position_embedding(seq_position_ids),
            ],
            -1,
        )
        residual = None
        if len(save_layer) > 0:
            hidden_states_dict = {}
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states, residual, inference_params=inference_params
            )
            if i + 1 in save_layer:
                hidden_states_dict[i + 1] = (
                    hidden_states.detach().cpu().to(torch.float).numpy()
                )
        if len(save_layer) > 0:
            return hidden_states_dict

        if not self.fused_add_norm:
            residual = (
                (hidden_states + residual) if residual is not None else hidden_states
            )
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            fused_add_norm_fn = (
                rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
            )
            hidden_states = fused_add_norm_fn(
                hidden_states,
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
            )
        return hidden_states

class MambaLMHeadModelSafe(nn.Module, GenerationMixinSafe):

    def __init__(
        self,
        config: MambaConfig,
        initializer_cfg=None,
        device=None,
        dtype=None,
        checkpoint_mixer=False,
    ) -> None:
        self.config = config
        d_model = config.d_model
        n_layer = config.n_layer
        vocab_size = config.vocab_size
        ssm_cfg = config.ssm_cfg
        rms_norm = config.rms_norm
        residual_in_fp32 = config.residual_in_fp32
        fused_add_norm = config.fused_add_norm
        pad_vocab_size_multiple = config.pad_vocab_size_multiple
        factory_kwargs = {"device": device, "dtype": dtype}
        if checkpoint_mixer:
            raise NotImplementedError(
                "Checkpointing is not yet supported for MambaLMHeadModelSafe"
            )

        super().__init__()
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (
                vocab_size % pad_vocab_size_multiple
            )
        self.backbone = MixerModelSafe(
            d_model=d_model,
            n_layer=n_layer,
            vocab_size=vocab_size,
            ssm_cfg=ssm_cfg,
            rms_norm=rms_norm,
            initializer_cfg=initializer_cfg,
            fused_add_norm=fused_add_norm,
            residual_in_fp32=residual_in_fp32,
            **factory_kwargs,
        )
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)

        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )
        self.tie_weights()

    def tie_weights(self):
        self.lm_head.weight = self.backbone.embedding.weight

    def clip_grad_norm_(self, max_norm, norm_type=2.0):
        r"""Clip the norm of the gradients for the model.
        Args:
            max_norm (float or int): The maximum norm of the gradients.
                The gradients are modified in-place.
            norm_type (float or int): The type of the used p-norm. Can be 'inf' for infinity norm.
        Returns:
            Total norm of the parameters (viewed as a single vector).
        """
        return torch.nn.utils.clip_grad_value_(self.parameters(), max_norm)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.backbone.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(
        self,
        input_ids,
        position_ids=None,
        inference_params=None,
        num_last_tokens=0,
        save_layer=[],
        *args,
        **kwargs,
    ):
        """
        "position_ids" is just to be compatible with Transformer generation. We don't use it.
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        return self.protected_forward(
            input_ids, position_ids, inference_params, num_last_tokens, save_layer
        )

    def protected_forward(
        self,
        input_ids,
        position_ids=None,
        inference_params=None,
        num_last_tokens=0,
        save_layer=[],
    ):
        hidden_states = self.backbone(
            input_ids, inference_params=inference_params, save_layer=save_layer
        )
        if len(save_layer) > 0:
            return hidden_states
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        lm_logits = self.lm_head(hidden_states)
        CausalLMOutput = namedtuple("CausalLMOutput", ["loss", "logits"])
        return CausalLMOutput(loss=None, logits=lm_logits)

    @classmethod
    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
        config_data = load_config_hf(pretrained_model_name)
        config = MambaConfig(**config_data)
        model = cls(config, device=device, dtype=dtype, **kwargs)
        model.load_state_dict(
            load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype),
            strict=False,
        )
        return model

    def save_pretrained(self, save_directory):
        """
        Minimal implementation of save_pretrained for MambaLMHeadModel.
        Save the model and its configuration file to a directory.
        """
        # Ensure save_directory exists
        os.makedirs(save_directory, exist_ok=True)

        # Save the model's state_dict
        model_path = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.state_dict(), model_path)

        # Save the configuration of the model
        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "w") as f:
            json.dump(self.config.__dict__, f)

class MambaLMHeadModelwithPosids(nn.Module, GenerationMixinSafe):

    def __init__(
        self,
        config: MambaConfig,
        initializer_cfg=None,
        device=None,
        dtype=None,
        checkpoint_mixer=False,
    ) -> None:
        self.config = config
        d_model = config.d_model
        n_layer = config.n_layer
        vocab_size = config.vocab_size
        max_position_embeddings = config.max_position_embeddings
        ssm_cfg = config.ssm_cfg
        rms_norm = config.rms_norm
        residual_in_fp32 = config.residual_in_fp32
        fused_add_norm = config.fused_add_norm
        pad_vocab_size_multiple = config.pad_vocab_size_multiple
        factory_kwargs = {"device": device, "dtype": dtype}

        super().__init__()
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (
                vocab_size % pad_vocab_size_multiple
            )
        self.backbone = MixerModelWithPosids(
            d_model=d_model,
            n_layer=n_layer,
            vocab_size=vocab_size,
            max_position_embeddings=max_position_embeddings,
            ssm_cfg=ssm_cfg,
            rms_norm=rms_norm,
            initializer_cfg=initializer_cfg,
            fused_add_norm=fused_add_norm,
            residual_in_fp32=residual_in_fp32,
            checkpoint_mixer=checkpoint_mixer,
            **factory_kwargs,
        )
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)

        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )
        self.tie_weights()

    def tie_weights(self):
        self.lm_head.weight = self.backbone.embedding.weight

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.backbone.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(
        self,
        input_ids,
        position_ids=None,
        inference_params=None,
        num_last_tokens=0,
        save_layer=[],
        *args,
        **kwargs,
    ):
        """
        "position_ids" is just to be compatible with Transformer generation. We don't use it.
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        return self.protected_forward(
            input_ids, position_ids, inference_params, num_last_tokens, save_layer
        )

    def protected_forward(
        self,
        input_ids,
        position_ids=None,
        inference_params=None,
        num_last_tokens=0,
        save_layer=[],
    ):
        hidden_states = self.backbone(
            input_ids,
            position_ids=position_ids,
            inference_params=inference_params,
            save_layer=save_layer,
        )
        if len(save_layer) > 0:
            return hidden_states
        hidden_states = hidden_states[:, :, : self.config.d_model // 2]
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        lm_logits = self.lm_head(hidden_states)
        CausalLMOutput = namedtuple("CausalLMOutput", ["loss", "logits"])
        return CausalLMOutput(loss=None, logits=lm_logits)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name,
        device=None,
        dtype=None,
        checkpoint_mixer=False,
        **kwargs,
    ):
        config_data = load_config_hf(pretrained_model_name)
        config = MambaConfig(**config_data)
        model = cls(
            config,
            device=device,
            dtype=dtype,
            checkpoint_mixer=checkpoint_mixer,
            **kwargs,
        )
        state_dict = load_state_dict_hf(
            pretrained_model_name, device=device, dtype=dtype
        )
        if state_dict.keys() != model.state_dict().keys():
            if checkpoint_mixer:
                for key in model.state_dict().keys():
                    if "ckpt_layer" in key:
                        state_dict[key] = state_dict.pop(key.replace("ckpt_layer.", ""))
                print(
                    "Using a model that was pretrained without gradient checkpointing and now want to use it. Changed the keys of the state_dict to match the model's keys."
                )
            else:
                for key in list(state_dict.keys()):
                    if "ckpt_layer" in key:
                        state_dict[key.replace("ckpt_layer.", "")] = state_dict.pop(key)
                print(
                    "Using a model that was pretrained with gradient checkpointing but now do not want to use it. Changed the keys of the state_dict to match the model's keys."
                )
            assert (
                state_dict.keys() == model.state_dict().keys()
            ), "The keys of the state_dict do not match the model's keys."
        model.load_state_dict(state_dict)
        return model

    def save_pretrained(self, save_directory):
        """
        Minimal implementation of save_pretrained for MambaLMHeadModel.
        Save the model and its configuration file to a directory.
        """
        # Ensure save_directory exists
        os.makedirs(save_directory, exist_ok=True)

        # Save the model's state_dict
        model_path = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.state_dict(), model_path)

        # Save the configuration of the model
        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "w") as f:
            json.dump(self.config.__dict__, f)

class MambaLMHeadModelwith2DPosids(nn.Module, GenerationMixinSafe):

    def __init__(
        self,
        config: MambaConfig,
        initializer_cfg=None,
        device=None,
        dtype=None,
        checkpoint_mixer=False,
    ) -> None:
        self.config = config
        d_model = config.d_model
        n_layer = config.n_layer
        vocab_size = config.vocab_size
        max_position_embeddings = config.max_position_embeddings
        ssm_cfg = config.ssm_cfg
        rms_norm = config.rms_norm
        residual_in_fp32 = config.residual_in_fp32
        fused_add_norm = config.fused_add_norm
        pad_vocab_size_multiple = config.pad_vocab_size_multiple
        factory_kwargs = {"device": device, "dtype": dtype}

        super().__init__()
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (
                vocab_size % pad_vocab_size_multiple
            )
        self.backbone = MixerModelWith2DPosids(
            d_model=d_model,
            n_layer=n_layer,
            vocab_size=vocab_size,
            max_position_embeddings=max_position_embeddings,
            ssm_cfg=ssm_cfg,
            rms_norm=rms_norm,
            initializer_cfg=initializer_cfg,
            fused_add_norm=fused_add_norm,
            residual_in_fp32=residual_in_fp32,
            checkpoint_mixer=checkpoint_mixer,
            **factory_kwargs,
        )
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)

        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )
        self.tie_weights()

    def tie_weights(self):
        self.lm_head.weight = self.backbone.embedding.weight

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.backbone.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(
        self,
        input_ids,
        position_ids=None,
        seq_position_ids=None,
        inference_params=None,
        num_last_tokens=0,
        save_layer=[],
        *args,
        **kwargs,
    ):
        """
        "position_ids" is just to be compatible with Transformer generation. We don't use it.
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        return self.protected_forward(
            input_ids,
            position_ids,
            seq_position_ids,
            inference_params,
            num_last_tokens,
            save_layer,
        )

    def protected_forward(
        self,
        input_ids,
        position_ids=None,
        seq_position_ids=None,
        inference_params=None,
        num_last_tokens=0,
        save_layer=[],
    ):
        hidden_states = self.backbone(
            input_ids,
            position_ids=position_ids,
            seq_position_ids=seq_position_ids,
            inference_params=inference_params,
            save_layer=save_layer,
        )
        if len(save_layer) > 0:
            return hidden_states
        hidden_states = hidden_states[:, :, : self.backbone.d_embeddings]
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        lm_logits = self.lm_head(hidden_states)
        CausalLMOutput = namedtuple("CausalLMOutput", ["loss", "logits"])
        return CausalLMOutput(loss=None, logits=lm_logits)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name,
        device=None,
        dtype=None,
        checkpoint_mixer=False,
        **kwargs,
    ):
        config_data = load_config_hf(pretrained_model_name)
        config = MambaConfig(**config_data)
        model = cls(
            config,
            device=device,
            dtype=dtype,
            checkpoint_mixer=checkpoint_mixer,
            **kwargs,
        )
        state_dict = load_state_dict_hf(
            pretrained_model_name, device=device, dtype=dtype
        )
        if state_dict.keys() != model.state_dict().keys():
            if checkpoint_mixer:
                for key in model.state_dict().keys():
                    if "ckpt_layer" in key:
                        state_dict[key] = state_dict.pop(key.replace("ckpt_layer.", ""))
                print(
                    "Using a model that was pretrained without gradient checkpointing and now want to use it. Changed the keys of the state_dict to match the model's keys."
                )
            else:
                for key in list(state_dict.keys()):
                    if "ckpt_layer" in key:
                        state_dict[key.replace("ckpt_layer.", "")] = state_dict.pop(key)
                print(
                    "Using a model that was pretrained with gradient checkpointing but now do not want to use it. Changed the keys of the state_dict to match the model's keys."
                )
            assert (
                state_dict.keys() == model.state_dict().keys()
            ), "The keys of the state_dict do not match the model's keys."
        model.load_state_dict(state_dict)
        return model

    def save_pretrained(self, save_directory):
        """
        Minimal implementation of save_pretrained for MambaLMHeadModel.
        Save the model and its configuration file to a directory.
        """
        # Ensure save_directory exists
        os.makedirs(save_directory, exist_ok=True)

        # Save the model's state_dict
        model_path = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.state_dict(), model_path)

        # Save the configuration of the model
        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "w") as f:
            json.dump(self.config.__dict__, f)
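The three LM-head variants above differ only in how they embed position information, which is also how they are selected via the `_mamba_model` mapping in protxlstm/train.py below. A hedged usage sketch for the 1D-position variant, valid only at the parent revision of this commit; it assumes the mamba_ssm CUDA kernels are installed and uses a made-up checkpoint path.

import torch
from protxlstm.models.mamba import MambaLMHeadModelwithPosids

model = MambaLMHeadModelwithPosids.from_pretrained(
    "path/to/protmamba_checkpoint",   # hypothetical checkpoint directory
    device="cuda",
    dtype=torch.bfloat16,
)
input_ids = torch.randint(0, 30, (1, 32), device="cuda")
position_ids = torch.arange(32, device="cuda").unsqueeze(0)  # per-residue positions
print(model(input_ids, position_ids=position_ids).logits.shape)  # (1, 32, padded vocab)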
protxlstm/plot_utils.py
DELETED
@@ -1,26 +0,0 @@

cd = {  # use dependent on model-type!!
    "xLSTM": "#3073AD",
    "Transformers": "#4B9D7A",
    "Mamba": "#DF8953",
    "S4": "#D275AB",
    "Hyena": "#E86A61",
}

def setup_matplotlib():
    import matplotlib.pyplot as plt
    from tueplots import bundles, axes
    bundles.icml2022()
    plt.rcParams.update(bundles.icml2022())
    plt.rcParams.update(axes.lines(base_width=0.5))
    plt.rcParams["text.usetex"] = False
    plt.rcParams['font.family'] = "sans-serif"
    plt.rcParams['font.serif'] = 'Arial'
    plt.rcParams['legend.edgecolor'] = 'grey'
    plt.rcParams['legend.framealpha'] = 0.7
    plt.rcParams['lines.linewidth'] = 1.2
    plt.rcParams['axes.grid'] = True
    plt.rcParams['axes.grid.axis'] = 'both'
    plt.rcParams['grid.alpha'] = 0.2
    plt.rcParams['axes.grid'] = True
    plt.rcParams['axes.prop_cycle'] = plt.cycler(color=cd.values())
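As a usage note, `setup_matplotlib` mutates the global rcParams, so any plot created afterwards picks up the ICML style and the per-model color cycle automatically. A minimal sketch, assuming `tueplots` is installed and using invented example values (valid at the parent revision, where the module still exists):

import matplotlib.pyplot as plt
from protxlstm.plot_utils import setup_matplotlib

setup_matplotlib()                                    # ICML bundle + model color cycle
plt.plot([1, 2, 3], [0.9, 0.7, 0.6], label="xLSTM")   # made-up example values
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.savefig("example_curve.png")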
protxlstm/train.py
DELETED
@@ -1,338 +0,0 @@
# Original code from ProtMamba under Apache License 2.0.
#
# Modifications made by Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
# - Extended to training of xlstm and transformer-based models
# - Predefined splits instead of on-the-fly creation
# - Option to overwrite config parameters from the command line
# - wandb logging

import argparse
import os

import torch
from omegaconf import OmegaConf
from transformers import TrainingArguments

from protxlstm.dataloaders import ProteinMemmapDataset, ProteinDataCollator
from protxlstm.models.xlstm import xLSTMConfig, xLSTMLMHeadModel
from protxlstm.models.llama import TransformerConfig, TransformerLMHeadModel
from protxlstm.trainer import ProtTrainer, EarlyStoppingCallback, get_last_checkpoint
from protxlstm.utils import (
    AA_TO_ID,
    compute_metrics,
    is_zero_rank,
    parse_override_args,
    print_number_of_parameters,
    print_zero_rank,
    set_optimizer_and_scheduler,
    setup_wandb,
    load_model,
)

def run(config):
    """
    Run training loop.

    Args:
        config (dict): dictionary with the configuration parameters.
    """

    if config.model_type == 'llama':
        pe_kwargs = {
            'max_position_embeddings': config["model"]["max_position_embeddings"],
            'add_position_ids': '1d',
        }
    elif config.model_type == 'mamba':
        from protxlstm.models.mamba import MambaConfig, MambaLMHeadModelSafe, MambaLMHeadModelwithPosids, MambaLMHeadModelwith2DPosids
        pe_kwargs = {
            'max_position_embeddings': config["model"]["max_position_embeddings"],
            'max_seq_position_embeddings': config["model"]["max_seq_position_embeddings"],
            'add_position_ids': config["model"]["add_position_ids"]
        }
    else:
        position_embeddings = config["model"]["position_embeddings"]
        assert position_embeddings in ["none", "abs_1d", "abs_2d", "rot_1d", "rot_2d"]
        if position_embeddings != "none":
            position_embeddings = position_embeddings.split("_")[-1]
        pe_kwargs = {
            'max_position_embeddings': config["model"]["max_position_embeddings"],
            'max_seq_position_embeddings': config["model"]["max_seq_position_embeddings"],
            'add_position_ids': position_embeddings
        }

    # Setup WandB
    wandb_run_name = setup_wandb(config)

    # Load datasets
    dataset_params = {
        "msa_memmap_path": config["msa_memmap_path"],
        "msa_memmap_meta_path": config["msa_memmap_meta_path"],
        "sample": config["sample_sequences"],
        "max_msa_len": config["max_msa_len"],
        "reverse": False,
        "seed": config["seed_sequence_sampling"],
        "troubleshoot": False,
        "fim_strategy": config["fim_strategy"],
        "always_mask": config["always_mask"],
        **pe_kwargs,
    }
    train_dataset = ProteinMemmapDataset(subset_path=config["train_set"], **dataset_params)
    valid_dataset = ProteinMemmapDataset(subset_path=config["valid_set"], **dataset_params)
    train_eval_dataset = ProteinMemmapDataset(subset_path=config["train_eval_set"], **dataset_params)

    print(f'Train set size: {len(train_dataset)} Train eval set size: {len(train_eval_dataset)} Valid set size: {len(valid_dataset)}')

    assert (
        len(AA_TO_ID) == config["model"]["vocab_size"]
    ), f"Vocab size in the config file does not match the one in the code. It should be {len(AA_TO_ID)}."

    # Create data collator for batched training
    data_collator = ProteinDataCollator(max_sequence_length=config["max_msa_len"])

    # Check datatypes
    if config["dtype"] == "float32":
        dtype = torch.float32
    elif config["dtype"] == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise ValueError("dtype must be either float32 or bfloat16")

    # Initialize model
    if config.model_type == 'xlstm':

        # Load model for finetuning
        if config.finetune_model_path:
            # These fields are updated in the config loaded from the checkpoint
            config_update_kwargs = {
                "mlstm_backend": config["model"]["mlstm_block"]["mlstm"]["backend"],
                "mlstm_chunksize": config["model"]["mlstm_block"]["mlstm"]["chunk_size"],
                "checkpoint_blocks": config["model"]["checkpoint_blocks"],
                "rope_base_frequency": config["model"]["rope_base_frequency"]
            }
            model = load_model(
                config.finetune_model_path,
                model_class=xLSTMLMHeadModel,
                device="cuda",
                dtype=dtype,
                **config_update_kwargs
            )
        else:
            # Create new model
            xlstm_config = xLSTMConfig().init_from_dict(config["model"])
            model = xLSTMLMHeadModel(xlstm_config)

    elif config.model_type == 'mamba':

        _mamba_model = {
            "none": MambaLMHeadModelSafe,
            "1d": MambaLMHeadModelwithPosids,
            "2d": MambaLMHeadModelwith2DPosids,
        }
        Mamba = _mamba_model[config['model']["add_position_ids"]]

        # Load model for finetuning
        if config.finetune_model_path:
            model = load_model(
                config.finetune_model_path,
                model_class=Mamba,
                device="cuda",
                dtype=dtype,
                checkpoint_mixer=config["checkpoint_mixer"],
            )
        else:
            # Create new model
            mamba_config = MambaConfig(d_model=config['model']["d_model"],
                                       n_layer=config['model']["n_layer"],
                                       vocab_size=config['model']["vocab_size"],
                                       residual_in_fp32=config['model']["residual_in_fp32"])
            model = Mamba(mamba_config, dtype=dtype, checkpoint_mixer=config['model']["checkpoint_mixer"])

    elif config.model_type == 'llama':

        llama_config = TransformerConfig(
            d_model=config["model"]["d_model"],
            n_layer=config["model"]["n_layer"],
            n_heads=config["model"]["n_heads"],
            n_kv_heads=config["model"]["n_kv_heads"],
            bidirectional=config["model"]["bidirectional"],
            hidden_dim=config["model"]["hidden_dim"],
            multiple_of=config["model"]["multiple_of"],
            norm_eps=config["model"]["norm_eps"],
            max_length=config["model"]["max_length"],
            vocab_size=config["model"]["vocab_size"],
            dropout=config["model"]["dropout"],
            max_position_embeddings=config["model"]["max_position_embeddings"],
            rope_base_frequency=config["model"]["rope_base_frequency"],
        )

        model = TransformerLMHeadModel(llama_config)

    else:
        raise ValueError(f"Unsupported model_type: {config.model_type}. Expected 'xlstm', 'mamba', or 'llama'.")

    # TODO: Improve what we want to print
    if is_zero_rank():
        print_number_of_parameters(model)
    print_zero_rank(f"dtype: {config['dtype']}")
    print_zero_rank(f"Epochs: {config['num_epochs']}")
    print_zero_rank(f"Batch size per GPU: {config['batch_size']}")
    print_zero_rank(f"Gradient accumulation steps: {config['gradient_accumulation_steps']}")
    eff_batch_size = config["batch_size"] * config["gradient_accumulation_steps"]
    nr_gpus = torch.cuda.device_count()
    print_zero_rank(f"GPUs: {nr_gpus}")
    eff_batch_size *= nr_gpus
    print_zero_rank(f"Effective batch size: {eff_batch_size}")
    print_zero_rank(
        f"Steps per training epoch: {len(train_dataset) // config['batch_size']}, eff. steps: {len(train_dataset) // eff_batch_size}"
    )
    print_zero_rank(f"Steps per evaluation epoch: {len(valid_dataset) // config['batch_size']}")
    print_zero_rank(f"Max MSA length: {config['max_msa_len']}")
    ev_epochs = round(
        config["eval_steps"] * config["batch_size"] / len(train_dataset), 3
    )
    print_zero_rank(
        f"Evaluation every {config['eval_steps']} steps, i.e. {ev_epochs} epochs. Effectively every {config['eval_steps']*config['gradient_accumulation_steps']} steps, i.e. {ev_epochs*config['gradient_accumulation_steps']} epochs."
    )
    if config.model_type == 'xlstm' and config["model"]["checkpoint_blocks"]:
        print_zero_rank("Using gradient checkpointing")
    if config["compute_only_fim_loss"]:
        print_zero_rank("Computing only FIM loss for training")

    # Training callbacks
    es_callback = EarlyStoppingCallback(
        train_path=config["output_dir"] + '/' + wandb_run_name, config=config
    )
    callbacks = [es_callback]

    # Optimizer and schedulers
    optimizer, scheduler = set_optimizer_and_scheduler(
        config,
        len(train_dataset),
        model.parameters()
    )

    # Find checkpoint if available
    last_checkpoint = None
    if config.finetune_model_path is None:
        path = os.path.join(config["output_dir"], wandb_run_name)
        if os.path.exists(path):
            last_checkpoint = get_last_checkpoint(path)
            if last_checkpoint is None:
                print_zero_rank("No checkpoint found, starting training from scratch.")
            else:
                print_zero_rank(f"Resuming training from the last checkpoint: {last_checkpoint}")

    # Create trainer
    trainer = ProtTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset={"valid": valid_dataset, "train": train_eval_dataset},
        optimizers=(optimizer, scheduler),
        args=TrainingArguments(
            run_name=wandb_run_name,
            local_rank=int(os.getenv('LOCAL_RANK', '0')),
            learning_rate=config["learning_rate"],
            num_train_epochs=config["num_epochs"],
            per_device_train_batch_size=config["batch_size"],
            per_device_eval_batch_size=config["batch_size"],
            gradient_accumulation_steps=config["gradient_accumulation_steps"],
            eval_accumulation_steps=config["eval_accumulation_steps"],
            eval_strategy="steps",
            max_grad_norm=config["max_grad_norm"],
            bf16=config["dtype"] == "bfloat16",
            dataloader_num_workers=32,
            logging_steps=config["logging_steps"],
            eval_steps=config["eval_steps"],
            save_steps=config["save_steps"],
            output_dir=config["output_dir"] + '/' + wandb_run_name,
            logging_dir=config["output_dir"] + '/' + wandb_run_name,
            report_to="wandb" if is_zero_rank() else None,
            log_on_each_node=False,
            overwrite_output_dir=False,
            push_to_hub=False,
            label_names=["labels"],
        ),
        compute_only_fim_loss=config["compute_only_fim_loss"],
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=callbacks,
    )

    # Train model
    while True:
        if last_checkpoint is None and trainer.state.global_step == 0:
            eval_results = trainer.evaluate()
            print_zero_rank(
                f">>> Initial validation perplexity: {eval_results['eval_valid_perplexity/batch']:.2f}"
            )
        else:
            print_zero_rank(f"Resuming training from the last checkpoint: {last_checkpoint}")
        # Train
        trainer.train(resume_from_checkpoint=last_checkpoint)

        # Break training when the number of epochs is reached
        if (
            not es_callback.should_restart
            or trainer.state.epoch >= config["num_epochs"]
        ):
            eval_results = trainer.evaluate()
            print_zero_rank(
                f">>> Final perplexity: {eval_results['eval_valid_perplexity/batch']:.2f}"
            )
            break
        # If the training was interrupted because of a loss spike, restart from the last checkpoint
        last_checkpoint = es_callback.checkpoint_path

    return trainer

if __name__ == "__main__":

    # Default configuration file paths
    default_model_config = "configs/xlstm_default_config.yaml"
    default_train_config = "configs/train_default_config.yaml"

    parser = argparse.ArgumentParser(
        description="Train or finetune a model with the provided configuration."
    )
    parser.add_argument(
        "--model_config_path",
        type=str,
        default=default_model_config,
        help=f"Path to the model configuration file (default: {default_model_config})"
    )
    parser.add_argument(
        "--train_config_path",
        type=str,
        default=default_train_config,
        help=f"Path to the training and dataset configuration file (default: {default_train_config})"
    )
    parser.add_argument(
        "overrides",
        nargs=argparse.REMAINDER,
        help="Override configuration values using key=value format.",
    )

    args = parser.parse_args()

    # Check that the config files exist, or raise an error
    if not os.path.exists(args.model_config_path):
        raise FileNotFoundError(f"Model config file not found: {args.model_config_path}")
    if not os.path.exists(args.train_config_path):
        raise FileNotFoundError(f"Train config file not found: {args.train_config_path}")

    # Load the model and training configurations
    model_config = OmegaConf.load(args.model_config_path)
    train_config = OmegaConf.load(args.train_config_path)

    # Merge the model and training configurations
    config = OmegaConf.merge(model_config, train_config)

    # Parse overrides
    if args.overrides:
        overrides = parse_override_args(args.overrides)
        config.merge_with(OmegaConf.create(overrides))

    # Run the training/finetuning process
    run(config)
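The `__main__` block above merges two YAML configs and then folds in command-line overrides produced by `parse_override_args` (defined in `protxlstm.utils`, not shown in this diff). A minimal sketch of the same merge pattern using only OmegaConf's public API; `OmegaConf.from_dotlist` stands in for `parse_override_args`, and the config keys are illustrative:

from omegaconf import OmegaConf

# Stand-ins for OmegaConf.load(...) on the two YAML files.
model_config = OmegaConf.create({"model": {"d_model": 512, "n_layer": 16}})
train_config = OmegaConf.create({"batch_size": 4, "learning_rate": 1e-4})

# Later arguments win on key conflicts, as in train.py's OmegaConf.merge call.
config = OmegaConf.merge(model_config, train_config)

# Command-line overrides arrive as ["key=value", ...] via argparse.REMAINDER.
overrides = ["model.d_model=1024", "batch_size=8"]
config.merge_with(OmegaConf.from_dotlist(overrides))

assert config.model.d_model == 1024 and config.batch_size == 8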
protxlstm/trainer.py
DELETED
@@ -1,123 +0,0 @@
# Original code from ProtMamba under Apache License 2.0.
#
# Modifications made by Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
# - MambaTrainer renamed to ProtTrainer

import os
import re

import torch
from transformers import Trainer, TrainerCallback

from protxlstm.utils import AA_TO_ID, find_fim_indices

class ProtTrainer(Trainer):
    """
    Base HuggingFace Trainer used for training.

    from https://github.com/havenhq/mamba-chat/blob/main/trainer/mamba_trainer.py"""
    def __init__(self, compute_only_fim_loss, **kwargs):
        super().__init__(**kwargs)
        self.compute_only_fim_loss = compute_only_fim_loss

    def compute_loss(self, model, inputs, return_outputs=False):
        input_ids = inputs.pop("input_ids")
        labels = inputs.pop("labels")
        if "seq_position_ids" in inputs and "position_ids" in inputs:
            position_ids = inputs.pop("position_ids")
            seq_position_ids = inputs.pop("seq_position_ids")
            output = model(input_ids, position_ids=position_ids, seq_position_ids=seq_position_ids)
        elif "position_ids" in inputs:
            position_ids = inputs.pop("position_ids")
            output = model(input_ids, position_ids=position_ids)
        else:
            output = model(input_ids)
        lm_logits = output.logits

        labels = labels.to(lm_logits.device)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss()
        if self.compute_only_fim_loss:
            # start and end tokens
            is_cls_tokens = (labels == AA_TO_ID["<cls>"])
            is_eos_tokens = (labels == AA_TO_ID["<eos>"])
            bool_fim = find_fim_indices(is_cls_tokens, is_eos_tokens)
            # include also the cls token
            bool_fim = bool_fim | is_cls_tokens
            inds = torch.where(bool_fim)
            lm_loss = loss_fct(shift_logits[inds[0], inds[1], :], labels[bool_fim])
        else:
            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))

        return (lm_loss, output) if return_outputs else lm_loss

    def save_model(self, output_dir, _internal_call):
        if int(os.getenv('LOCAL_RANK', '0')) == 0:
            self.model.save_pretrained(output_dir)

PREFIX_CHECKPOINT_DIR = "checkpoint"
_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$")

def get_last_checkpoint(folder, max_steps=None):
    content = os.listdir(folder)
    checkpoints = [
        path
        for path in content
        if _re_checkpoint.search(path) is not None and os.path.isdir(os.path.join(folder, path))
    ]
    if len(checkpoints) == 0:
        return

    max_steps = max_steps if max_steps is not None else float("inf")
    def func(x):
        num = int(_re_checkpoint.search(x).groups()[0])
        return num if num < max_steps else -1
    return os.path.join(folder, max(checkpoints, key=func))

class EarlyStoppingCallback(TrainerCallback):
    def __init__(self, train_path, config=None):
        self.step_counter_reset = 0
        self.step_counter_stop = 0
        self.best_loss = None
        self.train_path = train_path
        self.patience = config["patience"]
        self.metric_name = config["early_stopping_metric"]
        self.checkpoint_path = None
        self.should_restart = False
        self.eval_steps = config["eval_steps"]
        self.loss_increase_factor = config["loss_increase_factor"]

    def get_checkpoint_path(self, max_steps):
        last_checkpoint = None
        if os.path.exists(self.train_path):
            last_checkpoint = get_last_checkpoint(self.train_path, max_steps)
            if last_checkpoint is None:
                print("No checkpoint found, starting training from scratch.")
            else:
                print(f"Max checkpoint allowed: {max_steps}, restarting from {last_checkpoint}.")
        return last_checkpoint

    def on_evaluate(self, args, state, control, model, metrics, **kwargs):
        if self.metric_name in metrics:
            if self.best_loss is None:
                self.best_loss = metrics[self.metric_name]
            elif self.best_loss * self.loss_increase_factor < metrics[self.metric_name]:
                self.step_counter += 1
                if self.step_counter >= self.patience:
                    checkpoint_path = self.get_checkpoint_path(max_steps=(state.global_step - self.patience * self.eval_steps))
                    control.should_training_stop = True
                    self.checkpoint_path = checkpoint_path
                    self.should_restart = True
            else:
                self.step_counter = 0
                self.best_loss = min(self.best_loss, metrics[self.metric_name])
                self.should_restart = False

    def on_train_begin(self, args, state, control, **kwargs):
        self.step_counter = 0
        self.best_loss = None
        self.should_restart = False
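When `compute_only_fim_loss` is set, `compute_loss` restricts the cross-entropy to the fill-in-the-middle region (plus the `<cls>` positions) rather than averaging over every token. `find_fim_indices` comes from `protxlstm.utils` and is not shown in this diff, so the sketch below hand-builds the boolean mask with toy tensors; only the masked-indexing pattern mirrors the branch above:

import torch

batch, seq_len, vocab = 2, 6, 40
torch.manual_seed(0)
shift_logits = torch.randn(batch, seq_len, vocab)   # stand-in for shifted model logits
labels = torch.randint(0, vocab, (batch, seq_len))  # stand-in for shifted labels

# Pretend positions 2..4 of each row form the FIM span (incl. the <cls> slot);
# in the repo this mask comes from find_fim_indices(is_cls_tokens, is_eos_tokens).
bool_fim = torch.zeros(batch, seq_len, dtype=torch.bool)
bool_fim[:, 2:5] = True

loss_fct = torch.nn.CrossEntropyLoss()
inds = torch.where(bool_fim)
fim_loss = loss_fct(shift_logits[inds[0], inds[1], :], labels[bool_fim])

# For comparison: the unmasked loss over all positions, as in the else branch.
full_loss = loss_fct(shift_logits.view(-1, vocab), labels.view(-1))
print(fim_loss.item(), full_loss.item())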
run.sh
DELETED
@@ -1,6 +0,0 @@
#!/bin/bash
CONDA_ENV=$(head -1 /code/environment.yml | cut -d" " -f2)
eval "$(conda shell.bash hook)"
conda activate $CONDA_ENV

streamlit run app.py --server.port 7860 --server.address 0.0.0.0