Update modeling_chatglm.py
modeling_chatglm.py  +14 -5
modeling_chatglm.py  CHANGED
@@ -597,7 +597,7 @@ class GLMTransformer(torch.nn.Module):
         if self.post_layer_norm:
             LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
             # Final layer norm before output.
-            self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
+            self.norm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                  dtype=config.torch_dtype)
 
         self.gradient_checkpointing = False
@@ -653,7 +653,7 @@ class GLMTransformer(torch.nn.Module):
 
         # Final layer norm.
         if self.post_layer_norm:
-            hidden_states = self.final_layernorm(hidden_states)
+            hidden_states = self.norm(hidden_states)
 
         return hidden_states, presents, all_hidden_states, all_self_attentions
 
@@ -740,7 +740,14 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         init_kwargs = {}
         if device is not None:
             init_kwargs["device"] = device
-
+
+        self.embed_tokens = nn.Embedding(
+            config.padded_vocab_size,
+            self.hidden_size,
+            dtype=config.torch_dtype,
+            device=device
+        )
+
         self.num_layers = config.num_layers
         self.multi_query_group_num = config.multi_query_group_num
         self.kv_channels = config.kv_channels
@@ -765,7 +772,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         if self.post_layer_norm:
             LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
             # Final layer norm before output.
-            self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
+            self.norm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                  dtype=config.torch_dtype)
 
         self.pre_seq_len = config.pre_seq_len
@@ -777,6 +784,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             self.prefix_encoder = PrefixEncoder(config)
             self.dropout = torch.nn.Dropout(0.1)
 
+        self.gradient_checkpointing = False
+
     def get_input_embeddings(self):
         return self.embedding.word_embeddings
 
@@ -882,7 +891,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
         # Final layer norm.
         if self.post_layer_norm:
-            hidden_states = self.final_layernorm(hidden_states)
+            hidden_states = self.norm(hidden_states)
 
         if not return_dict:
             return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
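For orientation, below is a minimal, self-contained sketch of the construction pattern these changes introduce: the word embedding built directly as a plain nn.Embedding from the config's padded_vocab_size and hidden size, and the final layer norm stored under the attribute norm and applied at the end of the stack when post_layer_norm is enabled. The config values are placeholders rather than the real ChatGLMConfig fields, and plain LayerNorm stands in for the RMSNorm that the actual model selects when config.rmsnorm is set.

import torch
import torch.nn as nn

# Placeholder values standing in for ChatGLMConfig fields; the real config
# supplies padded_vocab_size, hidden_size, layernorm_epsilon and torch_dtype.
padded_vocab_size = 65024
hidden_size = 4096
layernorm_epsilon = 1e-5
dtype = torch.float32
device = "cpu"

# Same construction as the lines added to ChatGLMModel.__init__:
embed_tokens = nn.Embedding(padded_vocab_size, hidden_size, dtype=dtype, device=device)

# ...and the final layer norm now stored under `norm`. The real model uses
# RMSNorm when config.rmsnorm is set; plain LayerNorm is used here as a stand-in.
norm = nn.LayerNorm(hidden_size, eps=layernorm_epsilon, dtype=dtype, device=device)

input_ids = torch.randint(0, padded_vocab_size, (1, 8))
hidden_states = embed_tokens(input_ids)   # -> [1, 8, hidden_size]
hidden_states = norm(hidden_states)       # applied only when post_layer_norm is enabled
print(hidden_states.shape)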