Install Zyphra's fork of mamba-ssm (https://github.com/Zyphra/mamba). It must be installed from source:

```bash
git clone https://github.com/Zyphra/mamba.git
cd mamba
pip install .
```

Use `pip install -e .` instead for editable mode (if you want to make changes to the mamba code).
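If an upstream `mamba-ssm` is already installed, it can be worth checking which copy Python will actually import. The sketch below uses only the standard library; `package_origin` is a hypothetical helper, not part of mamba-ssm. Run it from outside the cloned `mamba` directory, so that a local `mamba_ssm` source folder on the current path does not mask the installed package.

```python
import importlib.util
from typing import Optional


def package_origin(name: str) -> Optional[str]:
    """Return the file a top-level package would be imported from,
    or None if it cannot be found. The package itself is not imported."""
    spec = importlib.util.find_spec(name)
    if spec is None:
        return None
    return spec.origin


if __name__ == "__main__":
    # After `pip install -e .` this typically points inside the cloned repo;
    # after `pip install .` it points into site-packages.
    print(package_origin("mamba_ssm"))
```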
You should then be able to load any iteration with the following snippet (here, iteration 10,000):

```python
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

model = MambaLMHeadModel.from_pretrained("Zyphra/Mamba-370M", iteration=10_000, device="cuda")
```

If `iteration` is not specified, the model at the root of the repository is loaded; this is the final iteration (610,351).
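Because the repository root holds the final iteration, one way to switch between an intermediate checkpoint and the final model is to pass `iteration` only when one is requested. This is a sketch that assumes the `iteration` keyword behaves as described above; `load_kwargs` and `load_mamba_370m` are hypothetical helper names, and `mamba_ssm` is imported lazily so the argument-building logic can be checked without a GPU. The range check only uses the final iteration stated above; it does not know which intermediate iterations were actually saved.

```python
from typing import Any, Dict, Optional

REPO_ID = "Zyphra/Mamba-370M"
FINAL_ITERATION = 610_351  # final iteration, as stated in this model card


def load_kwargs(iteration: Optional[int] = None, device: str = "cuda") -> Dict[str, Any]:
    """Build keyword arguments for MambaLMHeadModel.from_pretrained.

    `iteration` is left out entirely when None, so the final model at the
    repository root is loaded.
    """
    kwargs: Dict[str, Any] = {"device": device}
    if iteration is not None:
        if not 0 <= iteration <= FINAL_ITERATION:
            raise ValueError(f"iteration must be in [0, {FINAL_ITERATION}]")
        kwargs["iteration"] = iteration
    return kwargs


def load_mamba_370m(iteration: Optional[int] = None, device: str = "cuda"):
    # Deferred import: only needed when a model is actually loaded.
    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

    return MambaLMHeadModel.from_pretrained(REPO_ID, **load_kwargs(iteration, device))
```

For example, `early = load_mamba_370m(10_000)` and `final = load_mamba_370m()` would load an early checkpoint and the final model side by side.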
The model was trained using the `EleutherAI/gpt-neox-20b` tokenizer.
Here is a snippet for text generation:

```python
import torch
import transformers

# `model` is the MambaLMHeadModel loaded in the previous snippet
tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
inp_ids = torch.as_tensor([tokenizer.encode("Hello! How are you?")]).to("cuda")
out_ids = model.generate(inp_ids, max_length=100)
print(tokenizer.decode(out_ids[0]))
```
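The generation snippet above uses the default decoding settings, which in upstream mamba-ssm are greedy (`top_k=1`). Assuming Zyphra's fork keeps upstream's `generate` keywords (`top_k`, `top_p`, `temperature`), sampling can be enabled as sketched below. In upstream mamba-ssm, `max_length` counts the prompt tokens as well as the generated ones, so the hypothetical helper `total_max_length` turns a budget of new tokens into a `max_length` value; `sample_text` is likewise a hypothetical name, and the sampling values are illustrative, not tuned.

```python
from typing import Sequence


def total_max_length(prompt_ids: Sequence[int], max_new_tokens: int) -> int:
    """Upstream mamba-ssm's `max_length` includes the prompt, so add its length."""
    if max_new_tokens < 1:
        raise ValueError("max_new_tokens must be at least 1")
    return len(prompt_ids) + max_new_tokens


def sample_text(model, tokenizer, prompt: str, max_new_tokens: int = 64,
                top_k: int = 50, top_p: float = 0.9, temperature: float = 0.8) -> str:
    # Deferred import so the length helper above runs without torch.
    import torch

    prompt_ids = tokenizer.encode(prompt)
    inp_ids = torch.as_tensor([prompt_ids]).to("cuda")
    out_ids = model.generate(
        inp_ids,
        max_length=total_max_length(prompt_ids, max_new_tokens),
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
    )
    # The returned sequence starts with the prompt; decode only the new tokens.
    return tokenizer.decode(out_ids[0, len(prompt_ids):])
```

With `model` and `tokenizer` from the snippets above, `print(sample_text(model, tokenizer, "Hello! How are you?"))` would print a sampled continuation.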