update
Browse files- README.md +93 -0
- modeling_bamboo.py +49 -46
README.md
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Introducation
|
2 |
+
|
3 |
+
Sparse computing is increasingly recognized as an important direction to improve the computational efficiency of large language models (LLM). For example, mixture of experts (MoE) methods show particular promise.
|
4 |
+
|
5 |
+
Recent studies ([Zhang el al., 2021](https://arxiv.org/abs/2110.01786); [Liu et al., 2023](https://openreview.net/pdf?id=wIPIhHd00i); [Mirzadeh et al., 2023](https://arxiv.org/abs/2310.04564)) reveal that LLMs inherently exhibit properties conducive to sparse computation when employing the ReLU activation function. This insight opens up new avenues for model efficiency, akin to MoE's selective activation. By dynamically choosing model parameters for computation, we can substantially boost efficiency.
|
6 |
+
|
7 |
+
However, the widespread adoption of ReLU-based models in the LLM field remains limited. Here we introduce a new 7B ReLU-based LLM, Bamboo(Github link:[https://github.com/SJTU-IPADS/Bamboo](https://github.com/SJTU-IPADS/Bamboo)), which boasts nearly 85% sparsity and performance levels on par with [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.1).
|
8 |
+
|
9 |
+
## Model Architecture
|
10 |
+
|
11 |
+
To push the model's sparsity, we add a ReLU component after GLU component, called dReLU(double ReLU). So our FFN network works as follows:
|
12 |
+
|
13 |
+
```Python
|
14 |
+
class BambooMLP(nn.Module):
|
15 |
+
def __init__(self, config):
|
16 |
+
super().__init__()
|
17 |
+
self.config = config
|
18 |
+
self.hidden_size = config.hidden_size
|
19 |
+
self.intermediate_size = config.intermediate_size
|
20 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
21 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
22 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
23 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
24 |
+
|
25 |
+
def forward(self, x):
|
26 |
+
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.act_fn(self.up_proj(x)))
|
27 |
+
```
|
28 |
+
|
29 |
+
## Training Details
|
30 |
+
|
31 |
+
In this section, we introduce the details of training our model, including types of data used, and hyperparameters.
|
32 |
+
|
33 |
+
We initialized the model weights to Mistral's model weights and modified the FFN structure to the dReLU structure, then continued pre-training for 200B tokens, divided into two phases:
|
34 |
+
|
35 |
+
**First phase**: For the proportion of training corpus, we followed the data mix ratio and sources of the StableLM-3B model ([link](https://stability.wandb.io/stability-llm/stable-lm/reports/StableLM-3B-4E1T--VmlldzoyMjU4?accessToken=u3zujipenkx5g7rtcj9qojjgxpconyjktjkli2po09nffrffdhhchq045vp0wyfo)), conducting a further pre-training with 150B tokens.
|
36 |
+
|
37 |
+
The following table shows the hyper-paramters we used in our training process.
|
38 |
+
|
39 |
+
| Hyper-parameters | |
|
40 |
+
| --------------------- | ----------- |
|
41 |
+
| GPUs | 64 80G-A800 |
|
42 |
+
| Learning Rate Control | Cosine |
|
43 |
+
| Peak Learning Rate | 5e-5 |
|
44 |
+
| Batch Size | 4M |
|
45 |
+
| Weight Decay | 0.1 |
|
46 |
+
|
47 |
+
**Second phase**: We further adjusted the training corpus ratio, incorporating more domain-specific datasets (e.g., Math, Coding), and continued training for 50B tokens.
|
48 |
+
|
49 |
+
| Hyper-parameters | |
|
50 |
+
| --------------------- | ----------- |
|
51 |
+
| GPUs | 64 80G-A800 |
|
52 |
+
| Learning Rate Control | Cosine |
|
53 |
+
| Peak Learning Rate | 5e-6 |
|
54 |
+
| Batch Size | 4M |
|
55 |
+
| Weight Decay | 0.01 |
|
56 |
+
|
57 |
+
## Performance Evaluation Results
|
58 |
+
|
59 |
+
Our evaluation is based on the framework lm-evaluation-harness and opencompass. The evaluation details are listed as follows:
|
60 |
+
|
61 |
+
- Huggingface LLM Leaderboard tasks.
|
62 |
+
- Other Popular Benchmarks: We report the average accuracies on Big Bench Hard (BBH) (3-shot), HumanEval.
|
63 |
+
|
64 |
+
| | MMLU | Winogrande | TruthfulQA | Hellaswag | GSM8K | Arc-C | HumanEval | BBH | Average |
|
65 |
+
| ------- | ------ | ---------- | ---------- | --------- | ------ | ------ | --------- | ---- | ------- |
|
66 |
+
| Ours | 0.6389 | 0.7593 | 0.4406 | 0.8217 | 0.5315 | 0.6195 | 0.256 | | |
|
67 |
+
| Mistral | 0.6265 | 0.7924 | 0.4262 | 0.8332 | 0.4018 | 0.6143 | 0.2621 | | |
|
68 |
+
|
69 |
+
## Speed Evaluation Results
|
70 |
+
|
71 |
+
We utilize [PowerInfer](https://arxiv.org/pdf/2312.12456.pdf), a state-of-the-art acceleration framework leveraging activation sparsity. Here we show the inference speed compared with llama.cpp/transformers.
|
72 |
+
|
73 |
+
## Limitation & Disclaimer
|
74 |
+
|
75 |
+
- Bamboo, having undergone training with only 200B tokens, may still exhibit performance gaps in certain tasks.
|
76 |
+
- The Bamboo model has only been trained on English-language datasets, hence its capabilities in other languages are still lacking.
|
77 |
+
- The model may produce unexpected outputs due to its size and probabilistic generation paradigm.
|
78 |
+
|
79 |
+
## License
|
80 |
+
|
81 |
+
The code is licensed under Apache-2.0, while model weights are fully open for academic research and also allow **free** commercial usage.
|
82 |
+
|
83 |
+
## Citation:
|
84 |
+
|
85 |
+
Please kindly cite using the following BibTeX:
|
86 |
+
|
87 |
+
```
|
88 |
+
@misc{bamboo,
|
89 |
+
title={Bamboo: Harmonizing Sparsity and Performance in Large Language Models},
|
90 |
+
author={Yixin Song, Haotong Xie, Zeyu Mi, Haibo Chen},
|
91 |
+
year={2024}
|
92 |
+
}
|
93 |
+
```
|
modeling_bamboo.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
# coding=utf-8
|
2 |
# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
|
|
|
3 |
#
|
4 |
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
5 |
# and OPT implementations in this library. It has been modified from its
|
@@ -72,11 +73,11 @@ def _get_unpad_data(attention_mask):
|
|
72 |
)
|
73 |
|
74 |
|
75 |
-
# Copied from transformers.models.
|
76 |
-
class
|
77 |
def __init__(self, hidden_size, eps=1e-6):
|
78 |
"""
|
79 |
-
|
80 |
"""
|
81 |
super().__init__()
|
82 |
self.weight = nn.Parameter(torch.ones(hidden_size))
|
@@ -91,8 +92,9 @@ class MistralRMSNorm(nn.Module):
|
|
91 |
|
92 |
|
93 |
# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral
|
|
|
94 |
# TODO @Arthur no longer copied from LLama after static cache
|
95 |
-
class
|
96 |
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
97 |
super().__init__()
|
98 |
|
@@ -166,7 +168,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
|
166 |
return q_embed, k_embed
|
167 |
|
168 |
|
169 |
-
class
|
170 |
def __init__(self, config):
|
171 |
super().__init__()
|
172 |
self.config = config
|
@@ -194,7 +196,8 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
|
194 |
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
195 |
|
196 |
|
197 |
-
|
|
|
198 |
"""
|
199 |
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
|
200 |
and "Generating Long Sequences with Sparse Transformers".
|
@@ -231,7 +234,7 @@ class MistralAttention(nn.Module):
|
|
231 |
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
232 |
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
233 |
|
234 |
-
self.rotary_emb =
|
235 |
self.head_dim,
|
236 |
max_position_embeddings=self.max_position_embeddings,
|
237 |
base=self.rope_theta,
|
@@ -322,9 +325,9 @@ class MistralAttention(nn.Module):
|
|
322 |
return attn_output, attn_weights, past_key_value
|
323 |
|
324 |
|
325 |
-
class
|
326 |
"""
|
327 |
-
|
328 |
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
329 |
flash attention and deal with padding tokens in case the input contains any of them.
|
330 |
"""
|
@@ -618,14 +621,14 @@ class MistralFlashAttention2(MistralAttention):
|
|
618 |
|
619 |
# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral
|
620 |
# TODO @Arthur no longer copied from LLama after static cache
|
621 |
-
class
|
622 |
"""
|
623 |
-
|
624 |
-
`
|
625 |
SDPA API.
|
626 |
"""
|
627 |
|
628 |
-
# Adapted from
|
629 |
def forward(
|
630 |
self,
|
631 |
hidden_states: torch.Tensor,
|
@@ -638,7 +641,7 @@ class MistralSdpaAttention(MistralAttention):
|
|
638 |
if output_attentions:
|
639 |
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
640 |
logger.warning_once(
|
641 |
-
"
|
642 |
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
643 |
)
|
644 |
return super().forward(
|
@@ -705,23 +708,23 @@ class MistralSdpaAttention(MistralAttention):
|
|
705 |
return attn_output, None, past_key_value
|
706 |
|
707 |
|
708 |
-
|
709 |
-
"eager":
|
710 |
-
"flash_attention_2":
|
711 |
-
"sdpa":
|
712 |
}
|
713 |
|
714 |
|
715 |
-
class
|
716 |
def __init__(self, config: BambooConfig, layer_idx: int):
|
717 |
super().__init__()
|
718 |
self.hidden_size = config.hidden_size
|
719 |
|
720 |
-
self.self_attn =
|
721 |
|
722 |
-
self.mlp =
|
723 |
-
self.input_layernorm =
|
724 |
-
self.post_attention_layernorm =
|
725 |
|
726 |
def forward(
|
727 |
self,
|
@@ -783,7 +786,7 @@ class MistralDecoderLayer(nn.Module):
|
|
783 |
return outputs
|
784 |
|
785 |
|
786 |
-
|
787 |
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
788 |
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
789 |
etc.)
|
@@ -801,14 +804,14 @@ MISTRAL_START_DOCSTRING = r"""
|
|
801 |
|
802 |
|
803 |
@add_start_docstrings(
|
804 |
-
"The bare
|
805 |
-
|
806 |
)
|
807 |
-
class
|
808 |
config_class = BambooConfig
|
809 |
base_model_prefix = "model"
|
810 |
supports_gradient_checkpointing = True
|
811 |
-
_no_split_modules = ["
|
812 |
_skip_keys_device_placement = "past_key_values"
|
813 |
_supports_flash_attn_2 = True
|
814 |
_supports_sdpa = True
|
@@ -826,7 +829,7 @@ class MistralPreTrainedModel(PreTrainedModel):
|
|
826 |
module.weight.data[module.padding_idx].zero_()
|
827 |
|
828 |
|
829 |
-
|
830 |
Args:
|
831 |
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
832 |
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
@@ -897,12 +900,12 @@ MISTRAL_INPUTS_DOCSTRING = r"""
|
|
897 |
|
898 |
|
899 |
@add_start_docstrings(
|
900 |
-
"The bare
|
901 |
-
|
902 |
)
|
903 |
-
class
|
904 |
"""
|
905 |
-
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`
|
906 |
|
907 |
Args:
|
908 |
config: BambooConfig
|
@@ -915,10 +918,10 @@ class MistralModel(MistralPreTrainedModel):
|
|
915 |
|
916 |
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
917 |
self.layers = nn.ModuleList(
|
918 |
-
[
|
919 |
)
|
920 |
self._attn_implementation = config._attn_implementation
|
921 |
-
self.norm =
|
922 |
|
923 |
self.gradient_checkpointing = False
|
924 |
# Initialize weights and apply final processing
|
@@ -930,7 +933,7 @@ class MistralModel(MistralPreTrainedModel):
|
|
930 |
def set_input_embeddings(self, value):
|
931 |
self.embed_tokens = value
|
932 |
|
933 |
-
@add_start_docstrings_to_model_forward(
|
934 |
def forward(
|
935 |
self,
|
936 |
input_ids: torch.LongTensor = None,
|
@@ -993,7 +996,7 @@ class MistralModel(MistralPreTrainedModel):
|
|
993 |
if is_padding_right:
|
994 |
raise ValueError(
|
995 |
"You are attempting to perform batched generation with padding_side='right'"
|
996 |
-
" this may lead to unexpected behaviour for Flash Attention version of
|
997 |
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
998 |
)
|
999 |
|
@@ -1078,12 +1081,12 @@ class MistralModel(MistralPreTrainedModel):
|
|
1078 |
)
|
1079 |
|
1080 |
|
1081 |
-
class BambooForCausalLM(
|
1082 |
_tied_weights_keys = ["lm_head.weight"]
|
1083 |
|
1084 |
def __init__(self, config):
|
1085 |
super().__init__(config)
|
1086 |
-
self.model =
|
1087 |
self.vocab_size = config.vocab_size
|
1088 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
1089 |
|
@@ -1108,7 +1111,7 @@ class BambooForCausalLM(MistralPreTrainedModel):
|
|
1108 |
def get_decoder(self):
|
1109 |
return self.model
|
1110 |
|
1111 |
-
@add_start_docstrings_to_model_forward(
|
1112 |
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
1113 |
def forward(
|
1114 |
self,
|
@@ -1266,9 +1269,9 @@ class BambooForCausalLM(MistralPreTrainedModel):
|
|
1266 |
|
1267 |
@add_start_docstrings(
|
1268 |
"""
|
1269 |
-
The
|
1270 |
|
1271 |
-
[`
|
1272 |
(e.g. GPT-2) do.
|
1273 |
|
1274 |
Since it does classification on the last token, it requires to know the position of the last token. If a
|
@@ -1277,14 +1280,14 @@ class BambooForCausalLM(MistralPreTrainedModel):
|
|
1277 |
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
1278 |
each row of the batch).
|
1279 |
""",
|
1280 |
-
|
1281 |
)
|
1282 |
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mistral, LLAMA->MISTRAL
|
1283 |
-
class
|
1284 |
def __init__(self, config):
|
1285 |
super().__init__(config)
|
1286 |
self.num_labels = config.num_labels
|
1287 |
-
self.model =
|
1288 |
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
1289 |
|
1290 |
# Initialize weights and apply final processing
|
@@ -1296,7 +1299,7 @@ class MistralForSequenceClassification(MistralPreTrainedModel):
|
|
1296 |
def set_input_embeddings(self, value):
|
1297 |
self.model.embed_tokens = value
|
1298 |
|
1299 |
-
@add_start_docstrings_to_model_forward(
|
1300 |
def forward(
|
1301 |
self,
|
1302 |
input_ids: torch.LongTensor = None,
|
|
|
1 |
# coding=utf-8
|
2 |
# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
# Copyright 2024 SJTU-IPADS AI and the HuggingFace Inc. team. All rights reserved.
|
4 |
#
|
5 |
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
6 |
# and OPT implementations in this library. It has been modified from its
|
|
|
73 |
)
|
74 |
|
75 |
|
76 |
+
# Copied from transformers.models.mistral.modeling_mistral.MistralRMSNorm with Mistral->Bamboo
|
77 |
+
class BambooRMSNorm(nn.Module):
|
78 |
def __init__(self, hidden_size, eps=1e-6):
|
79 |
"""
|
80 |
+
BambooRMSNorm is equivalent to T5LayerNorm
|
81 |
"""
|
82 |
super().__init__()
|
83 |
self.weight = nn.Parameter(torch.ones(hidden_size))
|
|
|
92 |
|
93 |
|
94 |
# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral
|
95 |
+
# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Bamboo
|
96 |
# TODO @Arthur no longer copied from LLama after static cache
|
97 |
+
class BambooRotaryEmbedding(nn.Module):
|
98 |
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
99 |
super().__init__()
|
100 |
|
|
|
168 |
return q_embed, k_embed
|
169 |
|
170 |
|
171 |
+
class BambooMLP(nn.Module):
|
172 |
def __init__(self, config):
|
173 |
super().__init__()
|
174 |
self.config = config
|
|
|
196 |
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
197 |
|
198 |
|
199 |
+
# Copied from transformers.models.mistral.modeling_mistral.MistralAttention
|
200 |
+
class BambooAttention(nn.Module):
|
201 |
"""
|
202 |
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
|
203 |
and "Generating Long Sequences with Sparse Transformers".
|
|
|
234 |
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
235 |
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
236 |
|
237 |
+
self.rotary_emb = BambooRotaryEmbedding(
|
238 |
self.head_dim,
|
239 |
max_position_embeddings=self.max_position_embeddings,
|
240 |
base=self.rope_theta,
|
|
|
325 |
return attn_output, attn_weights, past_key_value
|
326 |
|
327 |
|
328 |
+
class BambooFlashAttention2(BambooAttention):
|
329 |
"""
|
330 |
+
BAMBOO flash attention module. This module inherits from `BambooAttention` as the weights of the module stays
|
331 |
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
332 |
flash attention and deal with padding tokens in case the input contains any of them.
|
333 |
"""
|
|
|
621 |
|
622 |
# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral
|
623 |
# TODO @Arthur no longer copied from LLama after static cache
|
624 |
+
class BambooSdpaAttention(BambooAttention):
|
625 |
"""
|
626 |
+
Bamboo attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
627 |
+
`BambooAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
628 |
SDPA API.
|
629 |
"""
|
630 |
|
631 |
+
# Adapted from BambooAttention.forward
|
632 |
def forward(
|
633 |
self,
|
634 |
hidden_states: torch.Tensor,
|
|
|
641 |
if output_attentions:
|
642 |
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
643 |
logger.warning_once(
|
644 |
+
"BambooModel is using BambooSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
645 |
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
646 |
)
|
647 |
return super().forward(
|
|
|
708 |
return attn_output, None, past_key_value
|
709 |
|
710 |
|
711 |
+
BAMBOO_ATTENTION_CLASSES = {
|
712 |
+
"eager": BambooAttention,
|
713 |
+
"flash_attention_2": BambooFlashAttention2,
|
714 |
+
"sdpa": BambooSdpaAttention,
|
715 |
}
|
716 |
|
717 |
|
718 |
+
class BambooDecoderLayer(nn.Module):
|
719 |
def __init__(self, config: BambooConfig, layer_idx: int):
|
720 |
super().__init__()
|
721 |
self.hidden_size = config.hidden_size
|
722 |
|
723 |
+
self.self_attn = BAMBOO_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
|
724 |
|
725 |
+
self.mlp = BambooMLP(config)
|
726 |
+
self.input_layernorm = BambooRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
727 |
+
self.post_attention_layernorm = BambooRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
728 |
|
729 |
def forward(
|
730 |
self,
|
|
|
786 |
return outputs
|
787 |
|
788 |
|
789 |
+
BAMBOO_START_DOCSTRING = r"""
|
790 |
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
791 |
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
792 |
etc.)
|
|
|
804 |
|
805 |
|
806 |
@add_start_docstrings(
|
807 |
+
"The bare Bamboo Model outputting raw hidden-states without any specific head on top.",
|
808 |
+
BAMBOO_START_DOCSTRING,
|
809 |
)
|
810 |
+
class BambooPreTrainedModel(PreTrainedModel):
|
811 |
config_class = BambooConfig
|
812 |
base_model_prefix = "model"
|
813 |
supports_gradient_checkpointing = True
|
814 |
+
_no_split_modules = ["BambooDecoderLayer"]
|
815 |
_skip_keys_device_placement = "past_key_values"
|
816 |
_supports_flash_attn_2 = True
|
817 |
_supports_sdpa = True
|
|
|
829 |
module.weight.data[module.padding_idx].zero_()
|
830 |
|
831 |
|
832 |
+
BAMBOO_INPUTS_DOCSTRING = r"""
|
833 |
Args:
|
834 |
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
835 |
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
|
|
900 |
|
901 |
|
902 |
@add_start_docstrings(
|
903 |
+
"The bare Bamboo Model outputting raw hidden-states without any specific head on top.",
|
904 |
+
BAMBOO_START_DOCSTRING,
|
905 |
)
|
906 |
+
class BambooModel(BambooPreTrainedModel):
|
907 |
"""
|
908 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`BambooDecoderLayer`]
|
909 |
|
910 |
Args:
|
911 |
config: BambooConfig
|
|
|
918 |
|
919 |
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
920 |
self.layers = nn.ModuleList(
|
921 |
+
[BambooDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
922 |
)
|
923 |
self._attn_implementation = config._attn_implementation
|
924 |
+
self.norm = BambooRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
925 |
|
926 |
self.gradient_checkpointing = False
|
927 |
# Initialize weights and apply final processing
|
|
|
933 |
def set_input_embeddings(self, value):
|
934 |
self.embed_tokens = value
|
935 |
|
936 |
+
@add_start_docstrings_to_model_forward(BAMBOO_INPUTS_DOCSTRING)
|
937 |
def forward(
|
938 |
self,
|
939 |
input_ids: torch.LongTensor = None,
|
|
|
996 |
if is_padding_right:
|
997 |
raise ValueError(
|
998 |
"You are attempting to perform batched generation with padding_side='right'"
|
999 |
+
" this may lead to unexpected behaviour for Flash Attention version of Bamboo. Make sure to "
|
1000 |
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
1001 |
)
|
1002 |
|
|
|
1081 |
)
|
1082 |
|
1083 |
|
1084 |
+
class BambooForCausalLM(BambooPreTrainedModel):
|
1085 |
_tied_weights_keys = ["lm_head.weight"]
|
1086 |
|
1087 |
def __init__(self, config):
|
1088 |
super().__init__(config)
|
1089 |
+
self.model = BambooModel(config)
|
1090 |
self.vocab_size = config.vocab_size
|
1091 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
1092 |
|
|
|
1111 |
def get_decoder(self):
|
1112 |
return self.model
|
1113 |
|
1114 |
+
@add_start_docstrings_to_model_forward(BAMBOO_INPUTS_DOCSTRING)
|
1115 |
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
1116 |
def forward(
|
1117 |
self,
|
|
|
1269 |
|
1270 |
@add_start_docstrings(
|
1271 |
"""
|
1272 |
+
The Bamboo Model transformer with a sequence classification head on top (linear layer).
|
1273 |
|
1274 |
+
[`BambooForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
1275 |
(e.g. GPT-2) do.
|
1276 |
|
1277 |
Since it does classification on the last token, it requires to know the position of the last token. If a
|
|
|
1280 |
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
1281 |
each row of the batch).
|
1282 |
""",
|
1283 |
+
BAMBOO_START_DOCSTRING,
|
1284 |
)
|
1285 |
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mistral, LLAMA->MISTRAL
|
1286 |
+
class BambooForSequenceClassification(BambooPreTrainedModel):
|
1287 |
def __init__(self, config):
|
1288 |
super().__init__(config)
|
1289 |
self.num_labels = config.num_labels
|
1290 |
+
self.model = BambooModel(config)
|
1291 |
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
1292 |
|
1293 |
# Initialize weights and apply final processing
|
|
|
1299 |
def set_input_embeddings(self, value):
|
1300 |
self.model.embed_tokens = value
|
1301 |
|
1302 |
+
@add_start_docstrings_to_model_forward(BAMBOO_INPUTS_DOCSTRING)
|
1303 |
def forward(
|
1304 |
self,
|
1305 |
input_ids: torch.LongTensor = None,
|