license: other
xLSTM-7B
This xLSTM-7B was pre-trained on the DCLM dataset and selected high-quality data, for a total of approx. 2.3T tokens, using the xlstm-jax framework.
How to use it
First, install xlstm, which now uses the mlstm_kernels package for its Triton kernels:
pip install xlstm
pip install mlstm_kernels
For now, install the transformers repository fork from NX-AI (until the xLSTM integration is merged):
pip install 'transformers @ git+ssh://git@github.com/NX-AI/transformers.git@integrate_xlstm'
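Before loading the model, it can help to confirm that the packages above are importable. The sketch below is a hypothetical helper (`missing_modules` and `REQUIRED_MODULES` are not part of xlstm or transformers). It only checks that each top-level module can be located with `importlib.util.find_spec`, so it cannot tell whether the installed transformers is the NX-AI fork.

```python
import importlib.util

# Top-level import names for the packages installed above (hypothetical list;
# adjust it if your environment differs).
REQUIRED_MODULES = ("xlstm", "mlstm_kernels", "transformers")


def missing_modules(names=REQUIRED_MODULES):
    """Return the names from `names` that cannot be located on the Python path.

    importlib.util.find_spec locates a top-level module without importing it.
    It does NOT check versions, or that transformers is the NX-AI fork.
    """
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    print("missing packages:", ", ".join(missing) if missing else "none")
```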
Then load the model and generate text:
from transformers import AutoModelForCausalLM, AutoTokenizer
xlstm = AutoModelForCausalLM.from_pretrained("NX-AI/xLSTM-7b", device_map="auto")
# this is a fork of EleutherAI/gpt-neox-20b
tokenizer = AutoTokenizer.from_pretrained("NX-AI/xLSTM-7b")
# move the input ids to the device the model was placed on by device_map="auto"
tokens = tokenizer("Hello xLSTM, how are you doing?", return_tensors='pt')['input_ids'].to(xlstm.device)
out = xlstm.generate(tokens, max_new_tokens=20)
print(tokenizer.decode(out[0]))
If you cannot or do not want to use the Triton kernels, you can switch to the native PyTorch implementations:
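from transformers import AutoConfig, AutoModelForCausalLM
# load the default config, then replace all three mLSTM kernels with native PyTorch versions
xlstm_config = AutoConfig.from_pretrained("NX-AI/xLSTM-7b")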
xlstm_config.step_kernel = "native"
xlstm_config.chunkwise_kernel = "chunkwise--native_autograd"
xlstm_config.sequence_kernel = "native_sequence__native"
xlstm = AutoModelForCausalLM.from_pretrained("NX-AI/xLSTM-7b", config=xlstm_config, device_map="auto")
# verify selected kernels
from pprint import pprint
pprint(xlstm.backbone.blocks[0].mlstm_layer.config)
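The same switch can be made conditional. The sketch below assumes that a missing `triton` package is the reason to fall back to the native kernels; the helper names (`triton_available`, `apply_native_kernels_if_needed`) are hypothetical, and the kernel strings are copied from the configuration above.

```python
import importlib.util
from types import SimpleNamespace

# Kernel names copied from the native PyTorch configuration above.
NATIVE_KERNELS = {
    "step_kernel": "native",
    "chunkwise_kernel": "chunkwise--native_autograd",
    "sequence_kernel": "native_sequence__native",
}


def triton_available() -> bool:
    """Return True if the `triton` package can be located (it is not imported)."""
    return importlib.util.find_spec("triton") is not None


def apply_native_kernels_if_needed(config, force_native: bool = False):
    """Hypothetical helper: set the native kernels on `config` when Triton is
    missing or `force_native` is True; otherwise keep the config's defaults."""
    if force_native or not triton_available():
        for attr, value in NATIVE_KERNELS.items():
            setattr(config, attr, value)
    return config


if __name__ == "__main__":
    # Stand-in object for illustration. With the NX-AI fork you would pass
    # AutoConfig.from_pretrained("NX-AI/xLSTM-7b") instead, then hand the
    # result to AutoModelForCausalLM.from_pretrained(..., config=...).
    demo = SimpleNamespace(step_kernel="default", chunkwise_kernel="default",
                           sequence_kernel="default")
    print(vars(apply_native_kernels_if_needed(demo, force_native=True)))
```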
Speed results
[Figure: generation speed using torch.cuda.graph and torch.compile optimizations on one NVIDIA H100]
Performance
Using HuggingFace's lm_eval:
BBH | MMLU-Pro | Math | MUSR | GPQA | IfEval |
---|---|---|---|---|---|
0.381 | 0.242 | 0.036 | 0.379 | 0.280 | 0.244 |
Using HuggingFace's lighteval in the Leaderboard-v1 settings:
Arc-Challenge (25-shot) | MMLU (5-shot) | Hellaswag (10-shot) | Winogrande (5-shot) | TruthfulQA (0-shot) | GSM8k (5-shot) | OpenbookQA (5-shot) | PiQA (5-shot) |
---|---|---|---|---|---|---|---|
0.584 | 0.589 | 0.710 | 0.742 | 0.420 | 0.004 | 0.443 | 0.817 |
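To compare the two tables at a glance, one might take an unweighted mean of each row. The sketch below is only an illustration: the parser (`parse_score_table`) and the equal-weight mean are assumptions of this example, not the official aggregate of either evaluation setup.

```python
def parse_score_table(header: str, values: str) -> dict[str, float]:
    """Pair each column name in a pipe-table header with the score below it."""
    names = [cell.strip() for cell in header.split("|") if cell.strip()]
    scores = [float(cell) for cell in values.split("|") if cell.strip()]
    if len(names) != len(scores):
        raise ValueError("header and value rows have different column counts")
    return dict(zip(names, scores))


def unweighted_mean(scores: dict[str, float]) -> float:
    """Plain average over all columns (every task weighted equally)."""
    return sum(scores.values()) / len(scores)


# Rows copied verbatim from the two tables above.
LM_EVAL = parse_score_table(
    "BBH | MMLU-Pro | Math | MUSR | GPQA | IfEval |",
    "0.381 | 0.242 | 0.036 | 0.379 | 0.280 | 0.244 |",
)
LEADERBOARD_V1 = parse_score_table(
    "Arc-Challenge (25-shot) | MMLU (5-shot) | Hellaswag (10-shot) | "
    "Winogrande (5-shot) | TruthfulQA (0-shot) | GSM8k (5-shot) | "
    "OpenbookQA (5-shot) | PiQA (5-shot) |",
    "0.584 | 0.589 | 0.710 | 0.742 | 0.420 | 0.004 | 0.443 | 0.817 |",
)

if __name__ == "__main__":
    for name, table in (("lm_eval", LM_EVAL), ("lighteval v1", LEADERBOARD_V1)):
        print(f"{name}: {unweighted_mean(table):.3f} over {len(table)} tasks")
```

An equal-weight mean hides large per-task differences (for example, the near-zero GSM8k and Math scores), so the per-task columns remain the more informative view.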
License
NXAI Community License (see the LICENSE file)