Upload folder using huggingface_hub
modeling_qllama.py  (+17 -17)  CHANGED
@@ -176,7 +176,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed
 
 
-class LlamaMLP(nn.Module):
+class QLlamaMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
@@ -230,7 +230,7 @@ def eager_attention_forward(
     return attn_output, attn_weights
 
 
-class LlamaAttention(nn.Module):
+class QLlamaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
     def __init__(self, config: LlamaConfig, layer_idx: int):
@@ -306,14 +306,14 @@ class LlamaAttention(nn.Module):
         return attn_output, attn_weights
 
 
-class LlamaDecoderLayer(nn.Module):
+class QLlamaDecoderLayer(nn.Module):
     def __init__(self, config: LlamaConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
 
-        self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
+        self.self_attn = QLlamaAttention(config=config, layer_idx=layer_idx)
 
-        self.mlp = LlamaMLP(config)
+        self.mlp = QLlamaMLP(config)
         self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
@@ -381,11 +381,11 @@ LLAMA_START_DOCSTRING = r"""
     "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
     LLAMA_START_DOCSTRING,
 )
-class LlamaPreTrainedModel(PreTrainedModel):
+class QLlamaPreTrainedModel(PreTrainedModel):
     config_class = LlamaConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["LlamaDecoderLayer"]
+    _no_split_modules = ["QLlamaDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn_2 = True
     _supports_sdpa = True
@@ -486,7 +486,7 @@ LLAMA_INPUTS_DOCSTRING = r"""
     "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
     LLAMA_START_DOCSTRING,
 )
-class LlamaModel(LlamaPreTrainedModel):
+class QLlamaModel(QLlamaPreTrainedModel):
     """
     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
 
@@ -501,7 +501,7 @@ class LlamaModel(LlamaPreTrainedModel):
 
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
         self.layers = nn.ModuleList(
-            [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+            [QLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.rotary_emb = LlamaRotaryEmbedding(config=config)
@@ -750,14 +750,14 @@ class LlamaModel(LlamaPreTrainedModel):
 class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
 
 
-class QLlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
+class QLlamaForCausalLM(QLlamaPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
     _tp_plan = {"lm_head": "colwise_rep"}
     _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
 
     def __init__(self, config):
         super().__init__(config)
-        self.model = LlamaModel(config)
+        self.model = QLlamaModel(config)
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
@@ -890,11 +890,11 @@ class QLlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
     """,
     LLAMA_START_DOCSTRING,
 )
-class LlamaForSequenceClassification(LlamaPreTrainedModel):
+class QLlamaForSequenceClassification(QLlamaPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
-        self.model = LlamaModel(config)
+        self.model = QLlamaModel(config)
         self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
 
         # Initialize weights and apply final processing
@@ -989,13 +989,13 @@ SQuAD (a linear layer on top of the hidden-states output to compute `span start
     """,
     LLAMA_START_DOCSTRING,
 )
-class LlamaForQuestionAnswering(LlamaPreTrainedModel):
+class QLlamaForQuestionAnswering(QLlamaPreTrainedModel):
     base_model_prefix = "transformer"
 
     # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama
    def __init__(self, config):
         super().__init__(config)
-        self.transformer = LlamaModel(config)
+        self.transformer = QLlamaModel(config)
         self.qa_outputs = nn.Linear(config.hidden_size, 2)
 
         # Initialize weights and apply final processing
@@ -1076,11 +1076,11 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel):
     """,
     LLAMA_START_DOCSTRING,
 )
-class LlamaForTokenClassification(LlamaPreTrainedModel):
+class QLlamaForTokenClassification(QLlamaPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
-        self.model = LlamaModel(config)
+        self.model = QLlamaModel(config)
         if getattr(config, "classifier_dropout", None) is not None:
             classifier_dropout = config.classifier_dropout
         elif getattr(config, "hidden_dropout", None) is not None:
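Net effect of the commit: the decoder building blocks (QLlamaMLP, QLlamaAttention, QLlamaDecoderLayer), the base QLlamaPreTrainedModel/QLlamaModel classes, and all task heads now reference the QLlama* implementations instead of the stock Llama* ones, so the entire stack resolves to the custom code in modeling_qllama.py. Below is a minimal usage sketch for a checkpoint that ships this file as Hub custom code; the repository id is a placeholder, and it assumes the repo's config.json carries an auto_map entry pointing AutoModelForCausalLM at QLlamaForCausalLM so that trust_remote_code=True picks up these classes.

# Minimal sketch (assumptions noted above): loading a checkpoint that ships
# modeling_qllama.py as Hub custom code. "your-org/qllama-checkpoint" is a
# hypothetical repo id; trust_remote_code=True is what lets transformers import
# the QLlama* classes from the repository rather than the built-in Llama ones.
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-org/qllama-checkpoint"  # placeholder, replace with the actual repo

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,  # required for the remote QLlamaForCausalLM to be used
    torch_dtype="auto",
)

inputs = tokenizer("Hello from QLlama:", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))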