## LLaMA-8x265M-MoE

[💻 Code](https://github.com/JuncaiL/SpecMoE/)

👋 Very nice to meet you here~

❤️ This repo contains the model `LLaMA-8x265M-MoE` (970M parameters in total), which activates 2 of its 8 experts (332M activated parameters). The model is trained from scratch with FP32 precision. We first train it on the Wikipedia dataset for 1 epoch, and then on 10% of the C4 dataset (10 of its 1024 data shards) for 1 epoch. It is NOT fine-tuned on instruction pairs, so it may not perform well as a chatbot.

📢 This series also includes a dense version (without the MoE structure); see [🤗 this repo](https://huggingface.co/JuncaiL/llama-265m).

### 1. 🚀QuickStart

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# The repo ships custom modeling code, so trust_remote_code is required.
model_dir = "JuncaiL/llama-8x265m-moe"
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True)
model.eval()
model.to("cuda:0")

input_text = "Beijing is a famous city"
inputs = tokenizer(input_text, return_tensors="pt", return_token_type_ids=False)
inputs = inputs.to("cuda:0")

# Greedy decoding
pred = model.generate(**inputs, max_length=50, do_sample=False)
print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))
# Beijing is a famous city in China. It is the capital of the Beijing Province and the largest city in China. It is also the home of the world’s largest city, Beijing.
# The city is the
```

### 2. 📑Checkpoint Details and Evaluation

**Model Parameters**

| Model | #Experts | #Activated Experts | #Params | #Activated Params | FLOPs (T) per sample (seq=2048) | Model Weights |
| ------------------- | -------- | ------------------ | ------- | ----------------- | ------------------------------- | ------------------------------------------------------------ |
| 265M | - | - | 265M | 265M | 0.48 | [🤗 llama-265m](https://huggingface.co/JuncaiL/llama-265m) |
| 8 $\times$ 265M MoE | 8 | 2 | 970M | 332M | 0.76 | [🤗 llama-8x265m-moe](https://huggingface.co/JuncaiL/llama-8x265m-moe) |
| llama-7b | - | - | 7B | 7B | 25.29 | |

**Model Evaluation**

We use the "average number of tokens verified" $N$ (see this [reference](https://arxiv.org/abs/2305.09781)) as the metric to evaluate these models. Given the same input to the small speculative model and to llama-7b, this metric counts how many consecutive tokens of the speculative model's output, starting from the first predicted token, match the output of llama-7b.
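
As a minimal sketch (a hypothetical helper, not part of the released code), the count for one sample can be computed from the two models' greedy output token IDs like this:

```python
def tokens_verified(draft_ids, target_ids, prompt_len):
    """Number of consecutive tokens, counted from the first predicted
    position, where the draft model's output matches the target's."""
    n = 0
    for d, t in zip(draft_ids[prompt_len:], target_ids[prompt_len:]):
        if d != t:
            break
        n += 1
    return n

# Toy example: a 3-token prompt followed by 4 matching tokens, then divergence.
draft = [5, 9, 2, 11, 7, 31, 8, 40]
target = [5, 9, 2, 11, 7, 31, 8, 13]
print(tokens_verified(draft, target, prompt_len=3))  # 4
```

$N$ is this count averaged over the samples of a dataset.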

- **Average number of tokens verified**

| Dataset | 8 $\times$ 265M MoE | GPT without MoE |
| ------------------------------------- | ------------------- | --------------- |
| tatsu-lab/alpaca | 3.2362 | 3.0334 |
| alespalla/chatbot_instruction_prompts | 3.2031 | 3.0823 |
| web_questions | 2.7201 | 2.5541 |
| MohamedRashad/ChatGPT-prompts | 3.0954 | 2.9768 |

Suppose the small speculative model has a hit rate $p$ for the next token when given the same input. Then we have

$$ 1p + 2p^2 + 3p^3 + \cdots = N $$

Since $1p + 2p^2 + 3p^3 + \cdots = \frac{p}{(1-p)^2}$, this is the quadratic $Np^2 - (2N+1)p + N = 0$, and taking the root in $(0, 1)$ gives the hit rate:

$$ p = 1 + \frac{1-\sqrt{1+4N}}{2N}$$
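
For example, on `tatsu-lab/alpaca` the MoE model has $N = 3.2362$, giving $p = 1 + \frac{1 - \sqrt{1 + 4 \times 3.2362}}{2 \times 3.2362} \approx 0.578$, which is the first entry of the table below.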

- **Hit Rate**

| Dataset | 8 $\times$ 265M MoE | GPT without MoE |
| ------------------------------------- | ------------------- | --------------- |
| tatsu-lab/alpaca | 0.578 | 0.567 |
| alespalla/chatbot_instruction_prompts | 0.576 | 0.570 |
| web_questions | 0.550 | 0.540 |
| MohamedRashad/ChatGPT-prompts | 0.571 | 0.565 |

### 3. 🚧Limitations and Future Plans

For the MoE model, we have only shown how accurately this small speculative model approximates the output of llama-7b. In practice, achieving low latency also requires improving the implementation of our MoE layer: in this version, we compute the MoE output expert by expert (sequentially), and the computation of these experts still needs to be fused.
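
To illustrate the pattern described above, here is a toy sketch (made-up layer sizes and names, not the actual model code) of a top-2 MoE layer that evaluates experts one after another; fusing means replacing this per-expert loop with a single grouped computation:

```python
import torch
import torch.nn as nn

class NaiveMoE(nn.Module):
    """Toy top-2 MoE layer that evaluates experts sequentially (the current,
    unfused pattern). Sizes are illustrative, not the model's real config."""

    def __init__(self, d_model=64, d_ff=128, n_experts=8, top_k=2):
        super().__init__()
        self.router = nn.Linear(d_model, n_experts)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(d_model, d_ff), nn.SiLU(), nn.Linear(d_ff, d_model))
            for _ in range(n_experts)
        )
        self.top_k = top_k

    def forward(self, x):  # x: (num_tokens, d_model)
        weights, idx = torch.topk(self.router(x).softmax(dim=-1), self.top_k, dim=-1)
        out = torch.zeros_like(x)
        # Bottleneck: one small matmul per expert, launched one after another.
        for e, expert in enumerate(self.experts):
            rows, slots = (idx == e).nonzero(as_tuple=True)
            if rows.numel() == 0:
                continue
            out[rows] += weights[rows, slots].unsqueeze(-1) * expert(x[rows])
        return out

print(NaiveMoE()(torch.randn(10, 64)).shape)  # torch.Size([10, 64])
```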

### Acknowledgment

1. My implementation of the MoE structure is based on [llama-moe/LLaMA-MoE-v1-3_5B-2_8](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-2_8).
2. My inspiration for speculative inference comes from the paper "SpecInfer: Accelerating Generative Large Language Model Serving with Tree-based Speculative Inference and Verification" ([link](https://arxiv.org/abs/2305.09781)). I greatly appreciate the help and suggestions from the SpecInfer group. ❤️

### Citation

```
@misc{specmoe-2024,
  title={SpecMoE: Building A Speculative MoE Model To Accelerate Inference},
  author={Juncai Liu},
  year={2024},
  month={March},
  url={https://github.com/JuncaiL/SpecMoE/}
}
```

### Contact

If you have any interest in or questions about this project, please feel free to contact me.

`[email protected]` (before June 30, 2024) or `[email protected]` (after June 30, 2024)