semran1 committed
Commit 9402361 · verified · 1 Parent(s): ddefc34

Upload folder using huggingface_hub
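
For reference, this is the kind of commit produced by huggingface_hub's upload_folder API (the commit message above matches its default for folder uploads). A minimal, hypothetical sketch — the local path and repo id below are placeholders, not taken from this commit:

from huggingface_hub import HfApi

# Hypothetical sketch: folder_path and repo_id are placeholders.
# Omitting commit_message yields the default seen in this commit.
api = HfApi()
api.upload_folder(
    folder_path="./qllama",        # local folder containing modeling_qllama.py, config, weights
    repo_id="semran1/qllama",      # placeholder repo id
    repo_type="model",
)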

Files changed (1):
modeling_qllama.py (+17, −17)
modeling_qllama.py CHANGED
@@ -176,7 +176,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed
 
 
-class LlamaMLP(nn.Module):
+class QLlamaMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
@@ -230,7 +230,7 @@ def eager_attention_forward(
     return attn_output, attn_weights
 
 
-class LlamaAttention(nn.Module):
+class QLlamaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
     def __init__(self, config: LlamaConfig, layer_idx: int):
@@ -306,14 +306,14 @@ class LlamaAttention(nn.Module):
     return attn_output, attn_weights
 
 
-class LlamaDecoderLayer(nn.Module):
+class QLlamaDecoderLayer(nn.Module):
     def __init__(self, config: LlamaConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
 
-        self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
+        self.self_attn = QLlamaAttention(config=config, layer_idx=layer_idx)
 
-        self.mlp = LlamaMLP(config)
+        self.mlp = QLlamaMLP(config)
         self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
@@ -381,11 +381,11 @@ LLAMA_START_DOCSTRING = r"""
     "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
     LLAMA_START_DOCSTRING,
 )
-class LlamaPreTrainedModel(PreTrainedModel):
+class QLlamaPreTrainedModel(PreTrainedModel):
     config_class = LlamaConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["LlamaDecoderLayer"]
+    _no_split_modules = ["QLlamaDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn_2 = True
     _supports_sdpa = True
@@ -486,7 +486,7 @@ LLAMA_INPUTS_DOCSTRING = r"""
     "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
     LLAMA_START_DOCSTRING,
 )
-class LlamaModel(LlamaPreTrainedModel):
+class QLlamaModel(QLlamaPreTrainedModel):
     """
     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
 
@@ -501,7 +501,7 @@ class LlamaModel(LlamaPreTrainedModel):
 
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
         self.layers = nn.ModuleList(
-            [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+            [QLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.rotary_emb = LlamaRotaryEmbedding(config=config)
@@ -750,14 +750,14 @@ class LlamaModel(LlamaPreTrainedModel):
 class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
 
 
-class QLlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
+class QLlamaForCausalLM(QLlamaPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
     _tp_plan = {"lm_head": "colwise_rep"}
     _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
 
     def __init__(self, config):
         super().__init__(config)
-        self.model = LlamaModel(config)
+        self.model = QLlamaModel(config)
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
@@ -890,11 +890,11 @@ class QLlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
     """,
     LLAMA_START_DOCSTRING,
 )
-class LlamaForSequenceClassification(LlamaPreTrainedModel):
+class QLlamaForSequenceClassification(QLlamaPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
-        self.model = LlamaModel(config)
+        self.model = QLlamaModel(config)
         self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
 
         # Initialize weights and apply final processing
@@ -989,13 +989,13 @@ SQuAD (a linear layer on top of the hidden-states output to compute `span start
     """,
     LLAMA_START_DOCSTRING,
 )
-class LlamaForQuestionAnswering(LlamaPreTrainedModel):
+class QLlamaForQuestionAnswering(QLlamaPreTrainedModel):
     base_model_prefix = "transformer"
 
     # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama
     def __init__(self, config):
         super().__init__(config)
-        self.transformer = LlamaModel(config)
+        self.transformer = QLlamaModel(config)
         self.qa_outputs = nn.Linear(config.hidden_size, 2)
 
         # Initialize weights and apply final processing
@@ -1076,11 +1076,11 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel):
     """,
     LLAMA_START_DOCSTRING,
 )
-class LlamaForTokenClassification(LlamaPreTrainedModel):
+class QLlamaForTokenClassification(QLlamaPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
-        self.model = LlamaModel(config)
+        self.model = QLlamaModel(config)
         if getattr(config, "classifier_dropout", None) is not None:
             classifier_dropout = config.classifier_dropout
         elif getattr(config, "hidden_dropout", None) is not None:
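
Because modeling_qllama.py ships with the repository rather than with transformers itself, the renamed QLlama* classes only come into play when the checkpoint is loaded with remote code enabled (and assuming the repo's config maps the auto classes to them). A minimal, hypothetical loading sketch — the repo id is a placeholder:

from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "semran1/qllama"  # placeholder; use the Hub repo that actually hosts modeling_qllama.py

# trust_remote_code=True lets transformers import QLlamaForCausalLM from the
# repo's modeling_qllama.py instead of the built-in Llama classes (assuming
# the config's auto_map references it).
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)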