metadata
license: other

xLSTM-7B

This xLSTM-7B was pre-trained on DCLM and selected high-quality data, for a total of approx. 2.3T tokens, using the xlstm-jax framework.

How to use it

First, install xlstm, which now uses the mlstm_kernels package for its Triton kernels:

pip install xlstm
pip install mlstm_kernels
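To confirm that both packages were installed into the active environment, a quick import check can help. This is a minimal sketch that assumes the import names match the pip package names (`xlstm` and `mlstm_kernels`); the helper `missing_modules` is hypothetical and only uses the standard library.

```python
import importlib.util

# Import names assumed to match the pip packages installed above.
REQUIRED_MODULES = ("xlstm", "mlstm_kernels")


def missing_modules(names=REQUIRED_MODULES):
    """Return the top-level module names that cannot be found on sys.path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("xlstm and mlstm_kernels are importable")
```

`find_spec` only locates the modules without importing them, so the check stays cheap even when the packages pull in heavy GPU dependencies.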

For now, install the transformers repository fork from NX-AI (until it is merged):

pip install 'transformers @ git+ssh://git@github.com/NX-AI/transformers.git@integrate_xlstm'
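Because a regular `transformers` release from PyPI can silently shadow the fork, it may be worth checking where the installed package came from. The sketch below assumes pip recorded the install source in the PEP 610 `direct_url.json` file, which it does for installs from a URL or VCS reference; the function name `transformers_install_source` is hypothetical.

```python
import json
from importlib import metadata


def transformers_install_source():
    """Return (url, requested_revision) for a URL/VCS install, else None.

    Reads the PEP 610 ``direct_url.json`` metadata file that pip writes
    when a package is installed from a URL such as the NX-AI git fork.
    """
    try:
        dist = metadata.distribution("transformers")
    except metadata.PackageNotFoundError:
        return None  # transformers is not installed at all
    raw = dist.read_text("direct_url.json")
    if raw is None:
        return None  # installed from an index (e.g. PyPI), not from a URL
    info = json.loads(raw)
    revision = info.get("vcs_info", {}).get("requested_revision")
    return info.get("url"), revision


if __name__ == "__main__":
    source = transformers_install_source()
    if source and "NX-AI/transformers" in (source[0] or ""):
        print("NX-AI fork installed from", source[0], "at revision", source[1])
    else:
        print("NX-AI transformers fork not detected:", source)
```

For the command above, the requested revision would be expected to read `integrate_xlstm`.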

Use the model as follows:

from transformers import AutoModelForCausalLM, AutoTokenizer

xlstm = AutoModelForCausalLM.from_pretrained("NX-AI/xLSTM-7b", device_map="auto")

# the tokenizer is a fork of the EleutherAI/gpt-neox-20b tokenizer
tokenizer = AutoTokenizer.from_pretrained("NX-AI/xLSTM-7b")

tokens = tokenizer("Hello xLSTM, how are you doing?", return_tensors='pt')['input_ids'].to(device="cuda")

out = xlstm.generate(tokens, max_new_tokens=20)

print(tokenizer.decode(out[0]))

If you cannot or do not want to use the Triton kernels, you can switch them to native PyTorch implementations:

from transformers import AutoConfig, AutoModelForCausalLM

xlstm_config = AutoConfig.from_pretrained("NX-AI/xLSTM-7b")
# replace the step, chunkwise and sequence kernels with native PyTorch versions
xlstm_config.step_kernel = "native"
xlstm_config.chunkwise_kernel = "chunkwise--native_autograd"
xlstm_config.sequence_kernel = "native_sequence__native"

xlstm = AutoModelForCausalLM.from_pretrained("NX-AI/xLSTM-7b", config=xlstm_config, device_map="auto")

# verify selected kernels
from pprint import pprint
pprint(xlstm.backbone.blocks[0].mlstm_layer.config)
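One way to apply the switch only when it is needed is to fall back to the native kernels automatically if Triton cannot be imported. This is a sketch under the assumption that a missing `triton` package is the main reason to avoid the Triton kernels (other causes, such as an unsupported GPU, are not detected). The kernel names are copied from the snippet above; the helpers `triton_available` and `maybe_use_native_kernels` are hypothetical, and a `SimpleNamespace` stands in for the real config so the example runs offline.

```python
import importlib.util
from types import SimpleNamespace

# Native PyTorch kernel names, taken from the configuration shown above.
NATIVE_KERNELS = {
    "step_kernel": "native",
    "chunkwise_kernel": "chunkwise--native_autograd",
    "sequence_kernel": "native_sequence__native",
}


def triton_available():
    """True if the `triton` package can be found in this environment."""
    return importlib.util.find_spec("triton") is not None


def maybe_use_native_kernels(config, force_native=False):
    """Switch `config` to native kernels if Triton is missing or forced.

    Returns the previous kernel settings, or an empty dict if unchanged.
    """
    if triton_available() and not force_native:
        return {}
    previous = {}
    for key, value in NATIVE_KERNELS.items():
        previous[key] = getattr(config, key, None)
        setattr(config, key, value)
    return previous


if __name__ == "__main__":
    # Stand-in for AutoConfig.from_pretrained("NX-AI/xLSTM-7b").
    cfg = SimpleNamespace(step_kernel="?", chunkwise_kernel="?", sequence_kernel="?")
    print("previous:", maybe_use_native_kernels(cfg, force_native=True))
    print("now:", cfg.step_kernel, cfg.chunkwise_kernel, cfg.sequence_kernel)
```

With the real model, `xlstm_config` from `AutoConfig.from_pretrained` would be passed in place of `cfg` before calling `AutoModelForCausalLM.from_pretrained(..., config=xlstm_config)`. Returning the previous values makes it easy to see which kernels the checkpoint selects by default.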

Speed results

Generation speed using torch.cuda.graph and torch.compile optimizations on one NVIDIA H100:

[Figure: generation speed]
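To get a rough tokens-per-second number on your own hardware, a simple timing loop is enough. This is a minimal sketch, not the benchmark setup used for the figure: it assumes `generate_fn` is any zero-argument callable that runs one generation and returns the number of new tokens it produced. Warm-up calls are excluded from timing so that one-time costs such as `torch.compile` compilation or CUDA graph capture do not distort the result. The helper name `tokens_per_second` is hypothetical.

```python
import time


def tokens_per_second(generate_fn, warmup=1, runs=3):
    """Return the average number of generated tokens per second.

    `generate_fn()` must perform one generation and return how many new
    tokens it produced. On a GPU it should synchronize before returning,
    otherwise asynchronous kernel launches make the timing too optimistic.
    """
    if runs < 1:
        raise ValueError("runs must be at least 1")
    for _ in range(warmup):
        generate_fn()  # excluded from timing (compilation, graph capture, caches)
    total_tokens = 0
    start = time.perf_counter()
    for _ in range(runs):
        total_tokens += generate_fn()
    elapsed = time.perf_counter() - start
    return total_tokens / elapsed
```

With the model loaded as above, `generate_fn` could call `xlstm.generate(tokens, max_new_tokens=20)`, then `torch.cuda.synchronize()`, and return `out.shape[1] - tokens.shape[1]`; counting from the output shape handles generations that stop early at an end-of-sequence token.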

Performance

[Figure: mmlu_train_token]

Using lm_eval (EleutherAI's lm-evaluation-harness):

| Task | Score |
|---|---|
| BBH | 0.381 |
| MMLU-Pro | 0.242 |
| MATH | 0.036 |
| MuSR | 0.379 |
| GPQA | 0.280 |
| IFEval | 0.244 |

Using HuggingFace's lighteval in the Leaderboard-v1 settings:

| Task | Shots | Score |
|---|---|---|
| ARC-Challenge | 25 | 0.584 |
| MMLU | 5 | 0.589 |
| HellaSwag | 10 | 0.710 |
| Winogrande | 5 | 0.742 |
| TruthfulQA | 0 | 0.420 |
| GSM8K | 5 | 0.004 |
| OpenBookQA | 5 | 0.443 |
| PIQA | 5 | 0.817 |
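For a single summary number per evaluation setup, the scores in the two tables can be averaged. This is only a plain unweighted arithmetic mean over the listed tasks; it is not the official Open LLM Leaderboard aggregate, which uses its own task selection and normalization. The dictionary names are illustrative.

```python
from statistics import mean

# Scores copied from the two tables above.
LM_EVAL = {
    "BBH": 0.381, "MMLU-Pro": 0.242, "MATH": 0.036,
    "MuSR": 0.379, "GPQA": 0.280, "IFEval": 0.244,
}
LIGHTEVAL_V1 = {
    "ARC-Challenge": 0.584, "MMLU": 0.589, "HellaSwag": 0.710,
    "Winogrande": 0.742, "TruthfulQA": 0.420, "GSM8K": 0.004,
    "OpenBookQA": 0.443, "PIQA": 0.817,
}

if __name__ == "__main__":
    for name, scores in (("lm_eval", LM_EVAL), ("lighteval v1", LIGHTEVAL_V1)):
        print(f"{name}: unweighted mean over {len(scores)} tasks = {mean(scores.values()):.3f}")
```

Note that a single near-zero entry such as GSM8K pulls an unweighted mean down noticeably, so the per-task numbers remain the more informative view.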

License

NXAI Community License (see LICENSE file)