---
license: other
---

# xLSTM-7B

This xLSTM-7B was pre-trained on the DCLM dataset and selected high-quality data, for a total of approx. 2.3T tokens, using the `xlstm-jax` framework.

## How to use it

First, install `xlstm`, which now uses the `mlstm_kernels` package for Triton kernels:

```bash
pip install xlstm
pip install mlstm_kernels
```

For now, install the `transformers` repository fork from NX-AI (until it is merged):

```bash
pip install 'transformers @ git+ssh://git@github.com/NX-AI/transformers.git@integrate_xlstm'
```

Use this model as follows:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

xlstm = AutoModelForCausalLM.from_pretrained("NX-AI/xLSTM-7b", device_map="auto")

# this is a fork of EleutherAI/gpt-neox-20b
tokenizer = AutoTokenizer.from_pretrained("NX-AI/xLSTM-7b")

tokens = tokenizer("Hello xLSTM, how are you doing?", return_tensors="pt")["input_ids"].to(device="cuda")

out = xlstm.generate(tokens, max_new_tokens=20)

print(tokenizer.decode(out[0]))
```

If you cannot or do not want to use the Triton kernels, you can switch to the native PyTorch implementations:

```python
from pprint import pprint

from transformers import AutoConfig, AutoModelForCausalLM

xlstm_config = AutoConfig.from_pretrained("NX-AI/xLSTM-7b")
xlstm_config.step_kernel = "native"
xlstm_config.chunkwise_kernel = "chunkwise--native_autograd"
xlstm_config.sequence_kernel = "native_sequence__native"

xlstm = AutoModelForCausalLM.from_pretrained("NX-AI/xLSTM-7b", config=xlstm_config, device_map="auto")

# verify the selected kernels
pprint(xlstm.backbone.blocks[0].mlstm_layer.config)
```

## Speed results

Generation speed using `torch.cuda.graph` and `torch.compile` optimizations on one NVIDIA H100:

![generation speed](plot_tokens_per_sec.svg)

## Performance

![MMLU vs. training tokens](MMLUvsTrainToken.svg)

Using `lm_eval` (the LM Evaluation Harness):

| BBH   | MMLU-Pro | MATH  | MuSR  | GPQA  | IFEval |
|-------|----------|-------|-------|-------|--------|
| 0.381 | 0.242    | 0.036 | 0.379 | 0.280 | 0.244  |

Using Hugging Face's `lighteval` in the Open LLM Leaderboard v1 settings:

| ARC-Challenge (25-shot) | MMLU (5-shot) | HellaSwag (10-shot) | Winogrande (5-shot) | TruthfulQA (0-shot) | GSM8K (5-shot) | OpenBookQA (5-shot) | PIQA (5-shot) |
|-------------------------|---------------|---------------------|---------------------|---------------------|----------------|---------------------|---------------|
| 0.584                   | 0.589         | 0.710               | 0.742               | 0.420               | 0.004          | 0.443               | 0.817         |

## License

NXAI Community License (see `LICENSE` file)
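
## Automatic kernel fallback (sketch)

The native-kernel switch described in "How to use it" can also be made conditional. The following is a minimal sketch, not part of `xlstm`, `mlstm_kernels`, or `transformers`: the helper names `triton_kernels_usable` and `select_kernels` are hypothetical, and treating "`triton` is importable and a CUDA device is visible" as sufficient for the Triton kernels is an assumption. The native kernel names are copied from the example above; when Triton looks usable, the config is left untouched so the checkpoint's own kernel defaults apply.

```python
import importlib.util
from types import SimpleNamespace

# Native PyTorch kernel names, copied from the native-kernel example in this card.
NATIVE_KERNELS = {
    "step_kernel": "native",
    "chunkwise_kernel": "chunkwise--native_autograd",
    "sequence_kernel": "native_sequence__native",
}


def triton_kernels_usable() -> bool:
    """Hypothetical heuristic: require the `triton` package and a visible CUDA device."""
    if importlib.util.find_spec("triton") is None:
        return False
    try:
        import torch
    except ImportError:
        return False
    return bool(torch.cuda.is_available())


def select_kernels(config, use_triton=None):
    """Hypothetical helper: switch `config` to native kernels unless Triton is usable.

    `config` is any object exposing the three kernel attributes, e.g. the result of
    `AutoConfig.from_pretrained("NX-AI/xLSTM-7b")`. If Triton is usable, the config
    is returned unchanged.
    """
    if use_triton is None:
        use_triton = triton_kernels_usable()
    if not use_triton:
        for name, value in NATIVE_KERNELS.items():
            setattr(config, name, value)
    return config


if __name__ == "__main__":
    # Stand-in object so the sketch runs without downloading the model config.
    demo = SimpleNamespace(step_kernel="default", chunkwise_kernel="default", sequence_kernel="default")
    print(select_kernels(demo, use_triton=False))
```

With the real model, you would pass `AutoConfig.from_pretrained("NX-AI/xLSTM-7b")` through `select_kernels` and then hand the result to `AutoModelForCausalLM.from_pretrained(..., config=...)`, as in the native-kernel example.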