---
license: apache-2.0
---
# Model Card for Zamba 7B

Zamba-7B-v1 is a hybrid of Mamba, a state-space model, and transformer attention. It uses a Mamba backbone with a shared transformer block applied every 6 Mamba blocks. Zamba was trained with next-token prediction and uses the Mistral v0.1 tokenizer. We arrived at this architecture after a series of ablations at small scales. Zamba-7B-v1 was pre-trained on 1T tokens of text and code sourced from open web datasets. In a second phase, Zamba was annealed on a mixture of 50B high-quality tokens.

Note: the current Hugging Face implementation of Zamba runs slower than our internal implementation. We are working with the Hugging Face team to fix this.

Our technical report describing the training of Zamba is available [here](https://arxiv.org/abs/2405.16712).

## Quick start

### Prerequisites

To download Zamba, clone Zyphra's fork of transformers:
1. `git clone https://github.com/Zyphra/transformers_zamba`
2. `cd transformers_zamba`
3. Install the repository: `pip install -e .`


In order to run optimized Mamba implementations on a CUDA device, you need to install `mamba-ssm` and `causal-conv1d`:
```bash
pip install mamba-ssm "causal-conv1d>=1.2.0"
```

You can run the model without using the optimized Mamba kernels, but it is **not** recommended as it will result in significantly higher latency. 

To run on CPU, specify `use_mamba_kernels=False` when loading the model with `AutoModelForCausalLM.from_pretrained`, as in the sketch below.
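
A minimal CPU-only loading sketch, assuming the fork above is installed (the prompt and generation settings are illustrative):

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba-7B-v1")
# use_mamba_kernels=False selects the pure-PyTorch Mamba path, which is
# required on CPU but noticeably slower than the fused CUDA kernels.
model = AutoModelForCausalLM.from_pretrained(
    "Zyphra/Zamba-7B-v1",
    use_mamba_kernels=False,
    torch_dtype=torch.float32,
)

inputs = tokenizer("What is a state-space model?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0]))
```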


### Inference

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba-7B-v1")
model = AutoModelForCausalLM.from_pretrained("Zyphra/Zamba-7B-v1", device_map="auto", torch_dtype=torch.bfloat16)

input_text = "What factors contributed to the fall of the Roman Empire?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=100)
print(tokenizer.decode(outputs[0]))
```

To load a different checkpoint, pass the `revision` argument; e.g., for iteration 2500:

```python
model = AutoModelForCausalLM.from_pretrained("Zyphra/Zamba-7B-v1", device_map="auto", torch_dtype=torch.bfloat16, revision="iter2500")
```

The default revision is the fully trained model, corresponding to iteration 25156, i.e., the number of training iterations performed starting from the phase-1 checkpoint [Zyphra/Zamba-7B-v1-phase1](https://huggingface.co/Zyphra/Zamba-7B-v1-phase1). See [arXiv:2405.16712](https://arxiv.org/abs/2405.16712) for more details on training.

## Model Details

Zamba utilizes a unique hybrid SSM architecture, consisting of a backbone of Mamba layers interspersed with a single shared attention block. Sharing the attention weights minimizes the parameter cost of the model. We find that concatenating the original input embeddings to the input of this attention block improves performance, likely because it helps maintain information across depth. A rough sketch of this layer pattern follows the architecture diagram below.


<center>
<img src="https://cdn-uploads.huggingface.co/production/uploads/65c05e75c084467acab2f84a/IGK562oVTFSOQbpLavu7E.png" width="300" alt="Zamba architecture">
</center>
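
As a rough illustration of the layer pattern described above (a hypothetical sketch, not the actual Zyphra implementation; the class names, head count, and use of `nn.MultiheadAttention` as the shared block are assumptions):

```python
import torch
import torch.nn as nn


class SharedAttentionBlock(nn.Module):
    """Single attention block whose weights are reused at every application."""

    def __init__(self, hidden_size: int, num_heads: int = 16):
        super().__init__()
        # The block sees the current hidden state concatenated with the
        # original token embeddings, hence the 2 * hidden_size input width.
        self.attn = nn.MultiheadAttention(2 * hidden_size, num_heads, batch_first=True)
        self.proj = nn.Linear(2 * hidden_size, hidden_size)

    def forward(self, hidden: torch.Tensor, embeddings: torch.Tensor) -> torch.Tensor:
        x = torch.cat([hidden, embeddings], dim=-1)
        attn_out, _ = self.attn(x, x, x, need_weights=False)
        return hidden + self.proj(attn_out)


class ZambaLikeBackbone(nn.Module):
    """Mamba-style backbone applying one shared attention block every `period` layers."""

    def __init__(self, mamba_layers: nn.ModuleList, hidden_size: int, period: int = 6):
        super().__init__()
        self.mamba_layers = mamba_layers  # stand-ins for real Mamba blocks
        self.shared_attn = SharedAttentionBlock(hidden_size)
        self.period = period

    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        hidden = embeddings
        for i, layer in enumerate(self.mamba_layers):
            hidden = layer(hidden)
            # Every `period` blocks, apply the same shared attention block,
            # feeding it the original embeddings alongside the hidden state.
            if (i + 1) % self.period == 0:
                hidden = self.shared_attn(hidden, embeddings)
        return hidden
```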


## Performance

We find that Zamba performs significantly better than existing open models (those with open datasets and training details) at this scale, though it performs slightly worse than the leading open-weight models at the 7B scale. Most of this gap comes from MMLU and reasoning evaluations. Zamba, however, is trained on significantly fewer tokens than these models and is the most sample-efficient in terms of performance per training token.


<center>
<img src="https://cdn-uploads.huggingface.co/production/uploads/65c05e75c084467acab2f84a/FG73iXpiDGSX_opbDJxKo.png" width="700" alt="Zamba performance">
</center>


Due to its SSM architecture, Zamba is highly efficient at inference, substantially outperforming comparable 7B and 8B models in both generation latency and memory cost thanks to its much smaller KV cache.

<center>
<img src="https://cdn-uploads.huggingface.co/production/uploads/65c05e75c084467acab2f84a/cghYPnDbdzweT1b2RyiXA.png" width="400" alt="Zamba performance">
</center>

## Citation

If you find Zamba useful in your work, please cite it as:

```bibtex
@article{glorioso2024zamba,
  title={Zamba: A Compact 7B SSM Hybrid Model},
  author={Glorioso, Paolo and Anthony, Quentin and Tokpanov, Yury and Whittington, James and Pilault, Jonathan and Ibrahim, Adam and Millidge, Beren},
  journal={arXiv preprint arXiv:2405.16712},
  year={2024}
}
```

## Notice

Zamba is a pretrained base model and therefore does not have any moderation mechanism. In addition, one should not expect good chat performance, as this model was not fine-tuned for chat.

## Paper

[arXiv:2405.16712](https://arxiv.org/abs/2405.16712)