## LLaMA-8x265M-MoE

[💻 Code](https://github.com/JuncaiL/SpecMoE/)

👋 Very nice to meet you here~

❤️ This repo contains the model `LLaMA-8x265M-MoE` (970M parameters in total), which activates 2 out of 8 experts (332M activated parameters). The model is trained from scratch in FP32 precision. We first train it on the Wikipedia dataset for 1 epoch and then on a subset of the C4 dataset (10 of its 1024 data shards, roughly 1% of C4) for 1 epoch. It is NOT fine-tuned on instruction pairs, so it may not perform well as a chatbot.

📢 This series also includes a dense version (without the MoE structure); see [🤗 this repo](https://huggingface.co/JuncaiL/llama-265m).



### 1. 🚀QuickStart

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_dir = "JuncaiL/llama-8x265m-moe"
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# trust_remote_code=True lets transformers load the custom model code shipped with the repo
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True)
model.eval()
model.to(device)

input_text = "Beijing is a famous city"
inputs = tokenizer(input_text, return_tensors="pt", return_token_type_ids=False)
inputs = inputs.to(device)

# Greedy decoding: use do_sample=False instead of temperature=0.0,
# which is not a valid sampling temperature
with torch.no_grad():
    pred = model.generate(**inputs, max_length=50, do_sample=False)
print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))
# Beijing is a famous city in China. It is the capital of the Beijing Province and the largest city in China. It is also the home of the world’s largest city, Beijing.
#The city is the
```



### 2. 📑Checkpoint Details and Evaluation

**Model Parameters**

| Model               | #Experts | #Activated Experts | #Params | #Activated Params | FLOPs (T) per sample (seq=2048) | Model Weights                                                |
| ------------------- | -------- | ------------------ | ------- | ----------------- | ------------------------------- | ------------------------------------------------------------ |
| 265M                | -        | -                  | 265M    | 265M              | 0.48                            | [🤗 llama-265m](https://huggingface.co/JuncaiL/llama-265m)    |
| 8 $\times$ 265M MoE | 8        | 2                  | 970M    | 332M              | 0.76                            | [🤗 llama-8x265m-moe](https://huggingface.co/JuncaiL/llama-8x265m-moe) |
| llama-7b            | -        | -                  | 7B      | 7B                | 25.29                           |                                                              |
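The gap between total and activated parameters can be split into a per-expert share and a shared share with a little arithmetic. This is a rough sketch that assumes the rounded totals in the table are exact and that every parameter is either shared by all tokens (embeddings, attention, norms, gate) or belongs to exactly one expert; the resulting numbers are estimates, not figures reported for this model.

```python
# Reported totals from the table above, in millions of parameters (rounded).
TOTAL_PARAMS_M = 970.0    # shared weights plus all 8 experts
ACTIVE_PARAMS_M = 332.0   # shared weights plus the 2 activated experts
NUM_EXPERTS = 8
NUM_ACTIVE = 2

# Assumed decomposition (not stated in the card):
#   total  = shared + NUM_EXPERTS * per_expert
#   active = shared + NUM_ACTIVE  * per_expert
per_expert_m = (TOTAL_PARAMS_M - ACTIVE_PARAMS_M) / (NUM_EXPERTS - NUM_ACTIVE)
shared_m = ACTIVE_PARAMS_M - NUM_ACTIVE * per_expert_m
active_fraction = ACTIVE_PARAMS_M / TOTAL_PARAMS_M

print(f"per expert ~ {per_expert_m:.1f}M, shared ~ {shared_m:.1f}M, "
      f"active fraction ~ {active_fraction:.0%}")
# per expert ~ 106.3M, shared ~ 119.3M, active fraction ~ 34%
```

Under these assumptions, each token touches roughly a third of the stored weights.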

**Model Evaluation**

We use the "Average number of tokens verified" $N$ (see reference [link](https://arxiv.org/abs/2305.09781)) as the metric to evaluate these models. Given the same input to the small speculative model and to llama-7b, this metric counts how many consecutive tokens, starting from the first predicted token, are identical in the two models' outputs.
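As described, the per-prompt count is the length of the common prefix of the two generated token sequences, and $N$ averages it over prompts. The sketch below is a minimal illustration that assumes both outputs are token-id lists produced by the same tokenizer; `tokens_verified`, `average_tokens_verified` and the toy token ids are hypothetical, and details of the original evaluation (output length, special tokens) are not stated in this card.

```python
from typing import List, Sequence, Tuple


def tokens_verified(draft: Sequence[int], target: Sequence[int]) -> int:
    """Count leading tokens of the draft output that match the target output."""
    count = 0
    for d, t in zip(draft, target):
        if d != t:
            break  # stop at the first mismatch
        count += 1
    return count


def average_tokens_verified(pairs: List[Tuple[Sequence[int], Sequence[int]]]) -> float:
    """Average tokens_verified over (draft, target) pairs, one pair per prompt."""
    return sum(tokens_verified(d, t) for d, t in pairs) / len(pairs)


# Hypothetical token ids, not taken from the evaluation in this card.
pairs = [
    ([5, 9, 2, 7], [5, 9, 3, 7]),  # first mismatch at position 2 -> 2
    ([1, 4, 4], [1, 4, 4]),        # all 3 tokens match -> 3
    ([8, 1], [6, 1]),              # mismatch at the first token -> 0
]
print(average_tokens_verified(pairs))  # (2 + 3 + 0) / 3, about 1.667
```

Note that a later match after a mismatch (the `7` in the first pair) does not count, because speculative verification stops at the first rejected token.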

- **Average number of tokens verified**

| Dataset                               | 8 $\times$ 265M MoE | GPT without MoE |
| ------------------------------------- | ------------------- | --------------- |
| tatsu-lab/alpaca                      | 3.2362              | 3.0334          |
| alespalla/chatbot_instruction_prompts | 3.2031              | 3.0823          |
| web_questions                         | 2.7201              | 2.5541          |
| MohamedRashad/ChatGPT-prompts         | 3.0954              | 2.9768          |

Suppose that, given the same input, the small speculative model predicts the next token correctly (i.e., the same token as llama-7b) with hit rate $p$. Then we have

$$ 1p + 2p^2 + 3p^3 + \cdots = \sum_{k=1}^{\infty} k p^k = \frac{p}{(1-p)^2} = N, \qquad 0 \le p < 1 $$

This is equivalent to the quadratic $Np^2 - (2N+1)p + N = 0$; taking its root in $(0, 1)$ gives the hit rate:

$$ p = 1 + \frac{1-\sqrt{1+4N}}{2N}$$

- **Hit Rate**

| Dataset                               | 8 $\times$ 265M MoE | GPT without MoE |
| ------------------------------------- | ------------------- | --------------- |
| tatsu-lab/alpaca                      | 0.578               | 0.567           |
| alespalla/chatbot_instruction_prompts | 0.576               | 0.570           |
| web_questions                         | 0.550               | 0.540           |
| MohamedRashad/ChatGPT-prompts         | 0.571               | 0.565           |
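The conversion from $N$ to $p$ can be checked numerically. This sketch simply applies the closed-form root given above to the "Average number of tokens verified" values (copied from the first table) and confirms that plugging $p$ back into $p/(1-p)^2$ recovers $N$; the helper names `hit_rate` and `expected_from_hit_rate` are illustrative, not part of the repo.

```python
import math


def hit_rate(n: float) -> float:
    """Root in (0, 1) of N p^2 - (2N + 1) p + N = 0, i.e. of p / (1 - p)^2 = N."""
    return 1 + (1 - math.sqrt(1 + 4 * n)) / (2 * n)


def expected_from_hit_rate(p: float) -> float:
    """Closed form of 1p + 2p^2 + 3p^3 + ... for 0 <= p < 1."""
    return p / (1 - p) ** 2


# Average number of tokens verified: (8 x 265M MoE, GPT without MoE).
verified = {
    "tatsu-lab/alpaca": (3.2362, 3.0334),
    "alespalla/chatbot_instruction_prompts": (3.2031, 3.0823),
    "web_questions": (2.7201, 2.5541),
    "MohamedRashad/ChatGPT-prompts": (3.0954, 2.9768),
}

for name, (moe_n, dense_n) in verified.items():
    print(f"{name:<40} {hit_rate(moe_n):.3f} {hit_rate(dense_n):.3f}")
```

Rounded to three decimals, the printed values match the Hit Rate table.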



### 3. 🚧Limitations and Future Plans

For the MoE model, we only show how closely this small speculative model approximates the outputs of llama-7b. In practice, achieving low wall-clock latency requires a better MoE implementation: this version computes the MoE output expert by expert (sequentially), and the computation of the experts still needs to be fused.
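To illustrate what "expert by expert" means, here is a minimal pure-Python sketch of a top-2-of-8 MoE layer whose experts run one after another, each on the tokens routed to it. It assumes a common top-k softmax gate and toy bias-free linear experts; `route_top_k`, `moe_sequential`, `make_linear_expert` and the random inputs are hypothetical, and this is not the repo's implementation (which builds on LLaMA-MoE). The per-token reference shows that the sequential loop changes only how the work is scheduled, not the result.

```python
import math
import random
from typing import Callable, List, Tuple

Vector = List[float]
Expert = Callable[[Vector], Vector]


def softmax(xs: Vector) -> Vector:
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def route_top_k(gate_logits: List[Vector], k: int) -> List[List[Tuple[int, float]]]:
    """Per token: keep the k highest-scoring experts and softmax over their logits."""
    routes = []
    for logits in gate_logits:
        top = sorted(range(len(logits)), key=lambda e: logits[e], reverse=True)[:k]
        routes.append(list(zip(top, softmax([logits[e] for e in top]))))
    return routes


def moe_sequential(tokens: List[Vector], gate_logits: List[Vector],
                   experts: List[Expert], k: int = 2) -> List[Vector]:
    """Loop over experts one at a time; each runs only on the tokens routed to it."""
    routes = route_top_k(gate_logits, k)
    out = [[0.0] * len(x) for x in tokens]
    for e, expert in enumerate(experts):  # sequential: one expert after another
        assigned = [(t, w) for t, route in enumerate(routes) for chosen, w in route if chosen == e]
        for t, w in assigned:
            y = expert(tokens[t])
            out[t] = [o + w * yi for o, yi in zip(out[t], y)]
    return out


def moe_reference(tokens: List[Vector], gate_logits: List[Vector],
                  experts: List[Expert], k: int = 2) -> List[Vector]:
    """Per-token reference: weighted sum of the outputs of each token's k experts."""
    result = []
    for x, route in zip(tokens, route_top_k(gate_logits, k)):
        ys = [(w, experts[e](x)) for e, w in route]
        result.append([sum(w * y[i] for w, y in ys) for i in range(len(x))])
    return result


def make_linear_expert(weight: List[Vector]) -> Expert:
    """Hypothetical toy expert: y = W x."""
    return lambda x: [sum(wij * xj for wij, xj in zip(row, x)) for row in weight]


random.seed(0)
DIM, NUM_EXPERTS, NUM_TOKENS = 4, 8, 5
experts = [make_linear_expert([[random.gauss(0, 1) for _ in range(DIM)] for _ in range(DIM)])
           for _ in range(NUM_EXPERTS)]
tokens = [[random.gauss(0, 1) for _ in range(DIM)] for _ in range(NUM_TOKENS)]
gate_logits = [[random.gauss(0, 1) for _ in range(NUM_EXPERTS)] for _ in range(NUM_TOKENS)]

sequential_out = moe_sequential(tokens, gate_logits, experts, k=2)
reference_out = moe_reference(tokens, gate_logits, experts, k=2)
```

A fused implementation would instead batch the work of all experts together (for example, into fewer, larger kernel launches on the GPU) rather than iterating over the experts in Python.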



### Acknowledgment

1. My implementation of the MoE structure is based on the repo `https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-2_8`.
2. My inspiration for speculative inference comes from the paper "SpecInfer: Accelerating Generative Large Language Model Serving with Tree-based Speculative Inference and Verification" ([link](https://arxiv.org/abs/2305.09781)). I greatly appreciate the help and suggestions from the SpecInfer group. ❤️



### Citation

```
@misc{specmoe-2024,
  title={SpecMoE: Building A Speculative MoE Model To Accelerate Inference},
  author={Juncai Liu},
  year={2024},
  month={March},
  url={https://github.com/JuncaiL/SpecMoE/}
}
```



### Contact

If you are interested in this project or have any questions, please feel free to contact me.

`[email protected]` (before June 30, 2024) or `[email protected]` (after June 30, 2024)