danielpark committed on
Commit 7885f9e · verified · 1 Parent(s): 1ae4683

Update README.md

Files changed (1):
  1. README.md +120 -2
README.md CHANGED
@@ -7,7 +7,7 @@ tags:
  - moe
  ---

-
+ # Please refrain from using this model yet. It does not contain any usable weights at all.

  # Expert weights of [Jamba-v0.1](https://huggingface.co/ai21labs/Jamba-v0.1)

@@ -15,6 +15,124 @@ Required Weights for Follow-up Research

  The original model is **[AI21lab's Jamba-v0.1](https://huggingface.co/ai21labs/Jamba-v0.1)**, which requires an **A100 80GB GPU**. Unfortunately, such a GPU was almost never available via Google Colab or other cloud computing services, so **MoE (Mixture of Experts) splitting** was attempted using the following resources as a basis:
  - **Original Model:** [Jamba-v0.1](https://huggingface.co/ai21labs/Jamba-v0.1)
- - **MoE Layer Separation**: Consult [this script](https://github.com/TechxGenus/Jamba-utils/blob/main/dense_downcycling.py) and using [TechxGenus/Jamba-v0.1-9B](https://huggingface.co/TechxGenus/Jamba-v0.1-9B).
+ - **MoE Layer Separation**: Consult [this script](https://github.com/TechxGenus/Jamba-utils/blob/main/dense_downcycling.py) written by [@TechxGenus](https://github.com/TechxGenus) and use [TechxGenus/Jamba-v0.1-9B](https://huggingface.co/TechxGenus/Jamba-v0.1-9B); a conceptual sketch of the down-cycling step is shown below.
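+
+ The actual separation was performed with the `dense_downcycling.py` script linked above; the snippet below is only a minimal sketch of the idea, and the state-dict key names (`...moe.experts.{e}...`, `...moe.router...`) are illustrative assumptions rather than the real Jamba checkpoint layout.
+
+ ```python
+ # Conceptual sketch of MoE "dense down-cycling": for every MoE layer, keep a
+ # single expert's FFN weights, drop the other experts and the router, and
+ # rename the surviving expert so it looks like a dense FFN.
+ def downcycle_state_dict(state_dict: dict, expert_id: int = 0) -> dict:
+     dense_sd = {}
+     for name, tensor in state_dict.items():
+         if ".moe.experts." in name:
+             if f".moe.experts.{expert_id}." in name:
+                 # Keep only the chosen expert, renamed to a dense FFN slot.
+                 dense_sd[name.replace(f".moe.experts.{expert_id}.", ".ffn.")] = tensor
+         elif ".moe.router." in name:
+             continue  # a dense model no longer needs the router
+         else:
+             dense_sd[name] = tensor
+     return dense_sd
+ ```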
+
+ ## Usage
+
+ The usage instructions below follow those of **[AI21lab's Jamba-v0.1](https://huggingface.co/ai21labs/Jamba-v0.1)**.
+
+ ### Prerequisites
+
+ To use Jamba, ensure you have `transformers` version 4.39.0 or higher installed (4.40.0 or higher is recommended, since it ships native Jamba support):
+ ```bash
+ pip install "transformers>=4.40.0"
+ ```
+
+ For optimized Mamba implementations, install `mamba-ssm` and `causal-conv1d`:
+ ```bash
+ pip install mamba-ssm "causal-conv1d>=1.2.0"
+ ```
+ Ensure the model is on a CUDA device.
+
+ You can run the model without the optimized Mamba kernels, but this is **not** recommended, as it results in significantly higher latency. To do so, specify `use_mamba_kernels=False` when loading the model, as in the sketch below.
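+
+ A minimal sketch of disabling the kernels, reusing the repository id from the examples below:
+
+ ```python
+ from transformers import AutoModelForCausalLM
+
+ # Falls back to the pure-PyTorch Mamba path; slower, but it works without the
+ # mamba-ssm / causal-conv1d packages.
+ model = AutoModelForCausalLM.from_pretrained("danielpark/asp-9b-inst-base",
+                                              use_mamba_kernels=False)
+ ```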
+
+ ### Run the model
+
+ ```python
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ model = AutoModelForCausalLM.from_pretrained("danielpark/asp-9b-inst-base")
+ tokenizer = AutoTokenizer.from_pretrained("danielpark/asp-9b-inst-base")
+
+ input_ids = tokenizer("In the recent Super Bowl LVIII,", return_tensors='pt').to(model.device)["input_ids"]
+
+ outputs = model.generate(input_ids, max_new_tokens=216)
+
+ print(tokenizer.batch_decode(outputs))
+ # ["In the recent Super Bowl LVIII, the Kansas City Chiefs emerged victorious, defeating the San Francisco 49ers in a thrilling overtime showdown. The game was a nail-biter, with both teams showcasing their skills and determination.\n\nThe Chiefs, led by their star quarterback Patrick Mahomes, displayed their offensive prowess, while the 49ers, led by their strong defense, put up a tough fight. The game went into overtime, with the Chiefs ultimately securing the win with a touchdown.\n\nThe victory marked the Chiefs' second Super Bowl win in four years, solidifying their status as one of the top teams in the NFL. The game was a testament to the skill and talent of both teams, and a thrilling end to the NFL season.\n\nThe Super Bowl is not just about the game itself, but also about the halftime show and the commercials. This year's halftime show featured a star-studded lineup, including Usher, Alicia Keys, and Lil Jon. The show was a spectacle of music and dance, with the performers delivering an energetic and entertaining performance.\n"]
+ ```
+
+ When using `transformers<4.40.0`, you must pass `trust_remote_code=True` to run the new Jamba architecture.
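+
+ For example (a sketch assuming an older `transformers` install; on 4.40.0+ the flag is unnecessary but harmless):
+
+ ```python
+ from transformers import AutoModelForCausalLM
+
+ # On transformers < 4.40.0 the Jamba architecture is not bundled with the
+ # library, so the modeling code is fetched from the model repository.
+ model = AutoModelForCausalLM.from_pretrained("danielpark/asp-9b-inst-base",
+                                              trust_remote_code=True)
+ ```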
+
+ <details>
+ <summary><strong>Loading the model in half precision</strong></summary>
+
+ The published checkpoint is saved in BF16. To load it into RAM in BF16/FP16, specify `torch_dtype`:
+
+ ```python
+ from transformers import AutoModelForCausalLM
+ import torch
+
+ model = AutoModelForCausalLM.from_pretrained("danielpark/asp-9b-inst-base",
+                                              torch_dtype=torch.bfloat16)  # you can also use torch_dtype=torch.float16
+ ```
+
+ When using half precision, enable the [FlashAttention2](https://github.com/Dao-AILab/flash-attention) implementation of the Attention blocks. To use it, ensure the model is on a CUDA device. If the model is too big to fit on a single GPU, parallelize it using [accelerate](https://huggingface.co/docs/accelerate/index):
+ ```python
+ from transformers import AutoModelForCausalLM
+ import torch
+
+ model = AutoModelForCausalLM.from_pretrained("danielpark/asp-9b-inst-base",
+                                              torch_dtype=torch.bfloat16,
+                                              attn_implementation="flash_attention_2",
+                                              device_map="auto")
+ ```
+
+ </details>
+ <details><summary><strong>Load the model in 8-bit</strong></summary>
+
+ **Using 8-bit precision, up to 140K sequence lengths can fit on a single 80GB GPU.** Quantize the model to 8-bit using [bitsandbytes](https://huggingface.co/docs/bitsandbytes/index). To prevent model quality degradation, exclude the Mamba blocks from quantization:
+
+ ```python
+ import torch
+ from transformers import AutoModelForCausalLM, BitsAndBytesConfig
+
+ quantization_config = BitsAndBytesConfig(load_in_8bit=True,
+                                          llm_int8_skip_modules=["mamba"])
+ model = AutoModelForCausalLM.from_pretrained("danielpark/asp-9b-inst-base",
+                                              torch_dtype=torch.bfloat16,
+                                              attn_implementation="flash_attention_2",
+                                              quantization_config=quantization_config)
+ ```
+ </details>
+
+ ### Fine-tuning example
+
+ Jamba is a base model that can be fine-tuned for custom solutions (including for chat/instruct versions). Fine-tune it using any technique of your choice. Here's an example of fine-tuning with the [PEFT](https://huggingface.co/docs/peft/index) library:
+
+ ```python
+ from datasets import load_dataset
+ from trl import SFTTrainer
+ from peft import LoraConfig
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
+
+ tokenizer = AutoTokenizer.from_pretrained("danielpark/asp-9b-inst-base")
+ model = AutoModelForCausalLM.from_pretrained("danielpark/asp-9b-inst-base", device_map='auto')
+
+ dataset = load_dataset("Abirate/english_quotes", split="train")
+ training_args = TrainingArguments(
+     output_dir="./results",
+     num_train_epochs=3,
+     per_device_train_batch_size=4,
+     logging_dir='./logs',
+     logging_steps=10,
+     learning_rate=2e-3
+ )
+ lora_config = LoraConfig(
+     r=8,
+     target_modules=["embed_tokens", "x_proj", "in_proj", "out_proj"],
+     task_type="CAUSAL_LM",
+     bias="none"
+ )
+ trainer = SFTTrainer(
+     model=model,
+     tokenizer=tokenizer,
+     args=training_args,
+     peft_config=lora_config,
+     train_dataset=dataset,
+     dataset_text_field="quote",
+ )
+
+ trainer.train()
+ ```
+

+ ## Further
  Check [ai21labs/Jamba-tiny-random](https://huggingface.co/ai21labs/Jamba-tiny-random), which has 128M parameters (instead of 52B) and is initialized with random weights without having undergone any training.
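+
+ As a rough sketch, that tiny random checkpoint can serve as a cheap stand-in for verifying that loading, generation, or fine-tuning code runs end to end (its outputs are meaningless by design):
+
+ ```python
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ # 128M-parameter, randomly initialized Jamba: loads quickly even on CPU and is
+ # useful only as a smoke test for pipeline code.
+ model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-tiny-random")
+ tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-tiny-random")
+
+ inputs = tokenizer("Hello, Jamba!", return_tensors="pt")
+ outputs = model.generate(**inputs, max_new_tokens=8)
+ print(tokenizer.batch_decode(outputs))
+ ```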