Feature Extraction
Transformers
Safetensors
English
bamboo
custom_code
yixinsong commited on
Commit
058d800
·
1 Parent(s): 6c59240

update modeling file

Browse files
Files changed (1) hide show
  1. modeling_bamboo.py +49 -46
modeling_bamboo.py CHANGED
@@ -1,5 +1,6 @@
1
  # coding=utf-8
2
  # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
 
3
  #
4
  # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
  # and OPT implementations in this library. It has been modified from its
@@ -72,11 +73,11 @@ def _get_unpad_data(attention_mask):
72
  )
73
 
74
 
75
- # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral
76
- class MistralRMSNorm(nn.Module):
77
  def __init__(self, hidden_size, eps=1e-6):
78
  """
79
- MistralRMSNorm is equivalent to T5LayerNorm
80
  """
81
  super().__init__()
82
  self.weight = nn.Parameter(torch.ones(hidden_size))
@@ -91,8 +92,9 @@ class MistralRMSNorm(nn.Module):
91
 
92
 
93
  # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral
 
94
  # TODO @Arthur no longer copied from LLama after static cache
95
- class MistralRotaryEmbedding(nn.Module):
96
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
97
  super().__init__()
98
 
@@ -166,7 +168,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
166
  return q_embed, k_embed
167
 
168
 
169
- class MistralMLP(nn.Module):
170
  def __init__(self, config):
171
  super().__init__()
172
  self.config = config
@@ -194,7 +196,8 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
194
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
195
 
196
 
197
- class MistralAttention(nn.Module):
 
198
  """
199
  Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
200
  and "Generating Long Sequences with Sparse Transformers".
@@ -231,7 +234,7 @@ class MistralAttention(nn.Module):
231
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
232
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
233
 
234
- self.rotary_emb = MistralRotaryEmbedding(
235
  self.head_dim,
236
  max_position_embeddings=self.max_position_embeddings,
237
  base=self.rope_theta,
@@ -322,9 +325,9 @@ class MistralAttention(nn.Module):
322
  return attn_output, attn_weights, past_key_value
323
 
324
 
325
- class MistralFlashAttention2(MistralAttention):
326
  """
327
- Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays
328
  untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
329
  flash attention and deal with padding tokens in case the input contains any of them.
330
  """
@@ -618,14 +621,14 @@ class MistralFlashAttention2(MistralAttention):
618
 
619
  # copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral
620
  # TODO @Arthur no longer copied from LLama after static cache
621
- class MistralSdpaAttention(MistralAttention):
622
  """
623
- Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
624
- `MistralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
625
  SDPA API.
626
  """
627
 
628
- # Adapted from MistralAttention.forward
629
  def forward(
630
  self,
631
  hidden_states: torch.Tensor,
@@ -638,7 +641,7 @@ class MistralSdpaAttention(MistralAttention):
638
  if output_attentions:
639
  # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
640
  logger.warning_once(
641
- "MistralModel is using MistralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
642
  'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
643
  )
644
  return super().forward(
@@ -705,23 +708,23 @@ class MistralSdpaAttention(MistralAttention):
705
  return attn_output, None, past_key_value
706
 
707
 
708
- MISTRAL_ATTENTION_CLASSES = {
709
- "eager": MistralAttention,
710
- "flash_attention_2": MistralFlashAttention2,
711
- "sdpa": MistralSdpaAttention,
712
  }
713
 
714
 
715
- class MistralDecoderLayer(nn.Module):
716
  def __init__(self, config: BambooConfig, layer_idx: int):
717
  super().__init__()
718
  self.hidden_size = config.hidden_size
719
 
720
- self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
721
 
722
- self.mlp = MistralMLP(config)
723
- self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
724
- self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
725
 
726
  def forward(
727
  self,
@@ -783,7 +786,7 @@ class MistralDecoderLayer(nn.Module):
783
  return outputs
784
 
785
 
786
- MISTRAL_START_DOCSTRING = r"""
787
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
788
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
789
  etc.)
@@ -801,14 +804,14 @@ MISTRAL_START_DOCSTRING = r"""
801
 
802
 
803
  @add_start_docstrings(
804
- "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
805
- MISTRAL_START_DOCSTRING,
806
  )
807
- class MistralPreTrainedModel(PreTrainedModel):
808
  config_class = BambooConfig
809
  base_model_prefix = "model"
810
  supports_gradient_checkpointing = True
811
- _no_split_modules = ["MistralDecoderLayer"]
812
  _skip_keys_device_placement = "past_key_values"
813
  _supports_flash_attn_2 = True
814
  _supports_sdpa = True
@@ -826,7 +829,7 @@ class MistralPreTrainedModel(PreTrainedModel):
826
  module.weight.data[module.padding_idx].zero_()
827
 
828
 
829
- MISTRAL_INPUTS_DOCSTRING = r"""
830
  Args:
831
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
832
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
@@ -897,12 +900,12 @@ MISTRAL_INPUTS_DOCSTRING = r"""
897
 
898
 
899
  @add_start_docstrings(
900
- "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
901
- MISTRAL_START_DOCSTRING,
902
  )
903
- class MistralModel(MistralPreTrainedModel):
904
  """
905
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MistralDecoderLayer`]
906
 
907
  Args:
908
  config: BambooConfig
@@ -915,10 +918,10 @@ class MistralModel(MistralPreTrainedModel):
915
 
916
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
917
  self.layers = nn.ModuleList(
918
- [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
919
  )
920
  self._attn_implementation = config._attn_implementation
921
- self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
922
 
923
  self.gradient_checkpointing = False
924
  # Initialize weights and apply final processing
@@ -930,7 +933,7 @@ class MistralModel(MistralPreTrainedModel):
930
  def set_input_embeddings(self, value):
931
  self.embed_tokens = value
932
 
933
- @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
934
  def forward(
935
  self,
936
  input_ids: torch.LongTensor = None,
@@ -993,7 +996,7 @@ class MistralModel(MistralPreTrainedModel):
993
  if is_padding_right:
994
  raise ValueError(
995
  "You are attempting to perform batched generation with padding_side='right'"
996
- " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
997
  " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
998
  )
999
 
@@ -1078,12 +1081,12 @@ class MistralModel(MistralPreTrainedModel):
1078
  )
1079
 
1080
 
1081
- class BambooForCausalLM(MistralPreTrainedModel):
1082
  _tied_weights_keys = ["lm_head.weight"]
1083
 
1084
  def __init__(self, config):
1085
  super().__init__(config)
1086
- self.model = MistralModel(config)
1087
  self.vocab_size = config.vocab_size
1088
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1089
 
@@ -1108,7 +1111,7 @@ class BambooForCausalLM(MistralPreTrainedModel):
1108
  def get_decoder(self):
1109
  return self.model
1110
 
1111
- @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
1112
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1113
  def forward(
1114
  self,
@@ -1266,9 +1269,9 @@ class BambooForCausalLM(MistralPreTrainedModel):
1266
 
1267
  @add_start_docstrings(
1268
  """
1269
- The Mistral Model transformer with a sequence classification head on top (linear layer).
1270
 
1271
- [`MistralForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1272
  (e.g. GPT-2) do.
1273
 
1274
  Since it does classification on the last token, it requires to know the position of the last token. If a
@@ -1277,14 +1280,14 @@ class BambooForCausalLM(MistralPreTrainedModel):
1277
  padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1278
  each row of the batch).
1279
  """,
1280
- MISTRAL_START_DOCSTRING,
1281
  )
1282
  # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mistral, LLAMA->MISTRAL
1283
- class MistralForSequenceClassification(MistralPreTrainedModel):
1284
  def __init__(self, config):
1285
  super().__init__(config)
1286
  self.num_labels = config.num_labels
1287
- self.model = MistralModel(config)
1288
  self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1289
 
1290
  # Initialize weights and apply final processing
@@ -1296,7 +1299,7 @@ class MistralForSequenceClassification(MistralPreTrainedModel):
1296
  def set_input_embeddings(self, value):
1297
  self.model.embed_tokens = value
1298
 
1299
- @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
1300
  def forward(
1301
  self,
1302
  input_ids: torch.LongTensor = None,
 
1
  # coding=utf-8
2
  # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
3
+ # Copyright 2024 SJTU-IPADS AI and the HuggingFace Inc. team. All rights reserved.
4
  #
5
  # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
6
  # and OPT implementations in this library. It has been modified from its
 
73
  )
74
 
75
 
76
+ # Copied from transformers.models.mistral.modeling_mistral.MistralRMSNorm with Mistral->Bamboo
77
+ class BambooRMSNorm(nn.Module):
78
  def __init__(self, hidden_size, eps=1e-6):
79
  """
80
+ BambooRMSNorm is equivalent to T5LayerNorm
81
  """
82
  super().__init__()
83
  self.weight = nn.Parameter(torch.ones(hidden_size))
 
92
 
93
 
94
  # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral
95
+ # copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Bamboo
96
  # TODO @Arthur no longer copied from LLama after static cache
97
+ class BambooRotaryEmbedding(nn.Module):
98
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
99
  super().__init__()
100
 
 
168
  return q_embed, k_embed
169
 
170
 
171
+ class BambooMLP(nn.Module):
172
  def __init__(self, config):
173
  super().__init__()
174
  self.config = config
 
196
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
197
 
198
 
199
+ # Copied from transformers.models.mistral.modeling_mistral.MistralAttention
200
+ class BambooAttention(nn.Module):
201
  """
202
  Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
203
  and "Generating Long Sequences with Sparse Transformers".
 
234
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
235
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
236
 
237
+ self.rotary_emb = BambooRotaryEmbedding(
238
  self.head_dim,
239
  max_position_embeddings=self.max_position_embeddings,
240
  base=self.rope_theta,
 
325
  return attn_output, attn_weights, past_key_value
326
 
327
 
328
+ class BambooFlashAttention2(BambooAttention):
329
  """
330
+ BAMBOO flash attention module. This module inherits from `BambooAttention` as the weights of the module stays
331
  untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
332
  flash attention and deal with padding tokens in case the input contains any of them.
333
  """
 
621
 
622
  # copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral
623
  # TODO @Arthur no longer copied from LLama after static cache
624
+ class BambooSdpaAttention(BambooAttention):
625
  """
626
+ Bamboo attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
627
+ `BambooAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
628
  SDPA API.
629
  """
630
 
631
+ # Adapted from BambooAttention.forward
632
  def forward(
633
  self,
634
  hidden_states: torch.Tensor,
 
641
  if output_attentions:
642
  # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
643
  logger.warning_once(
644
+ "BambooModel is using BambooSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
645
  'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
646
  )
647
  return super().forward(
 
708
  return attn_output, None, past_key_value
709
 
710
 
711
+ BAMBOO_ATTENTION_CLASSES = {
712
+ "eager": BambooAttention,
713
+ "flash_attention_2": BambooFlashAttention2,
714
+ "sdpa": BambooSdpaAttention,
715
  }
716
 
717
 
718
+ class BambooDecoderLayer(nn.Module):
719
  def __init__(self, config: BambooConfig, layer_idx: int):
720
  super().__init__()
721
  self.hidden_size = config.hidden_size
722
 
723
+ self.self_attn = BAMBOO_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
724
 
725
+ self.mlp = BambooMLP(config)
726
+ self.input_layernorm = BambooRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
727
+ self.post_attention_layernorm = BambooRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
728
 
729
  def forward(
730
  self,
 
786
  return outputs
787
 
788
 
789
+ BAMBOO_START_DOCSTRING = r"""
790
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
791
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
792
  etc.)
 
804
 
805
 
806
  @add_start_docstrings(
807
+ "The bare Bamboo Model outputting raw hidden-states without any specific head on top.",
808
+ BAMBOO_START_DOCSTRING,
809
  )
810
+ class BambooPreTrainedModel(PreTrainedModel):
811
  config_class = BambooConfig
812
  base_model_prefix = "model"
813
  supports_gradient_checkpointing = True
814
+ _no_split_modules = ["BambooDecoderLayer"]
815
  _skip_keys_device_placement = "past_key_values"
816
  _supports_flash_attn_2 = True
817
  _supports_sdpa = True
 
829
  module.weight.data[module.padding_idx].zero_()
830
 
831
 
832
+ BAMBOO_INPUTS_DOCSTRING = r"""
833
  Args:
834
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
835
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
 
900
 
901
 
902
  @add_start_docstrings(
903
+ "The bare Bamboo Model outputting raw hidden-states without any specific head on top.",
904
+ BAMBOO_START_DOCSTRING,
905
  )
906
+ class BambooModel(BambooPreTrainedModel):
907
  """
908
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`BambooDecoderLayer`]
909
 
910
  Args:
911
  config: BambooConfig
 
918
 
919
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
920
  self.layers = nn.ModuleList(
921
+ [BambooDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
922
  )
923
  self._attn_implementation = config._attn_implementation
924
+ self.norm = BambooRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
925
 
926
  self.gradient_checkpointing = False
927
  # Initialize weights and apply final processing
 
933
  def set_input_embeddings(self, value):
934
  self.embed_tokens = value
935
 
936
+ @add_start_docstrings_to_model_forward(BAMBOO_INPUTS_DOCSTRING)
937
  def forward(
938
  self,
939
  input_ids: torch.LongTensor = None,
 
996
  if is_padding_right:
997
  raise ValueError(
998
  "You are attempting to perform batched generation with padding_side='right'"
999
+ " this may lead to unexpected behaviour for Flash Attention version of Bamboo. Make sure to "
1000
  " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1001
  )
1002
 
 
1081
  )
1082
 
1083
 
1084
+ class BambooForCausalLM(BambooPreTrainedModel):
1085
  _tied_weights_keys = ["lm_head.weight"]
1086
 
1087
  def __init__(self, config):
1088
  super().__init__(config)
1089
+ self.model = BambooModel(config)
1090
  self.vocab_size = config.vocab_size
1091
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1092
 
 
1111
  def get_decoder(self):
1112
  return self.model
1113
 
1114
+ @add_start_docstrings_to_model_forward(BAMBOO_INPUTS_DOCSTRING)
1115
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1116
  def forward(
1117
  self,
 
1269
 
1270
  @add_start_docstrings(
1271
  """
1272
+ The Bamboo Model transformer with a sequence classification head on top (linear layer).
1273
 
1274
+ [`BambooForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1275
  (e.g. GPT-2) do.
1276
 
1277
  Since it does classification on the last token, it requires to know the position of the last token. If a
 
1280
  padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1281
  each row of the batch).
1282
  """,
1283
+ BAMBOO_START_DOCSTRING,
1284
  )
1285
  # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mistral, LLAMA->MISTRAL
1286
+ class BambooForSequenceClassification(BambooPreTrainedModel):
1287
  def __init__(self, config):
1288
  super().__init__(config)
1289
  self.num_labels = config.num_labels
1290
+ self.model = BambooModel(config)
1291
  self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1292
 
1293
  # Initialize weights and apply final processing
 
1299
  def set_input_embeddings(self, value):
1300
  self.model.embed_tokens = value
1301
 
1302
+ @add_start_docstrings_to_model_forward(BAMBOO_INPUTS_DOCSTRING)
1303
  def forward(
1304
  self,
1305
  input_ids: torch.LongTensor = None,