ybelkada committed
Commit 3ceaf1a · 1 Parent(s): 61cdd28

Update modeling_chatglm.py

Files changed (1)
  1. modeling_chatglm.py (+14 -5)
modeling_chatglm.py CHANGED
@@ -597,7 +597,7 @@ class GLMTransformer(torch.nn.Module):
         if self.post_layer_norm:
             LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
             # Final layer norm before output.
-            self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
+            self.norm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                  dtype=config.torch_dtype)
 
         self.gradient_checkpointing = False
@@ -653,7 +653,7 @@ class GLMTransformer(torch.nn.Module):
 
         # Final layer norm.
         if self.post_layer_norm:
-            hidden_states = self.final_layernorm(hidden_states)
+            hidden_states = self.norm(hidden_states)
 
         return hidden_states, presents, all_hidden_states, all_self_attentions
 
@@ -740,7 +740,14 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         init_kwargs = {}
         if device is not None:
             init_kwargs["device"] = device
-        self.embedding = init_method(Embedding, config, **init_kwargs)
+
+        self.embed_tokens = nn.Embedding(
+            config.padded_vocab_size,
+            self.hidden_size,
+            dtype=config.torch_dtype,
+            device=device
+        )
+
         self.num_layers = config.num_layers
         self.multi_query_group_num = config.multi_query_group_num
         self.kv_channels = config.kv_channels
@@ -765,7 +772,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         if self.post_layer_norm:
             LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
             # Final layer norm before output.
-            self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
+            self.norm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                  dtype=config.torch_dtype)
 
         self.pre_seq_len = config.pre_seq_len
@@ -777,6 +784,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             self.prefix_encoder = PrefixEncoder(config)
             self.dropout = torch.nn.Dropout(0.1)
 
+        self.gradient_checkpointing = False
+
     def get_input_embeddings(self):
         return self.embedding.word_embeddings
 
@@ -882,7 +891,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
         # Final layer norm.
         if self.post_layer_norm:
-            hidden_states = self.final_layernorm(hidden_states)
+            hidden_states = self.norm(hidden_states)
 
         if not return_dict:
             return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
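
For context, a minimal standalone sketch of the token-embedding construction this commit switches to. The values below (padded_vocab_size, hidden_size, torch_dtype, the example token ids) are hypothetical stand-ins; the real model reads them from its ChatGLMConfig.

import torch
import torch.nn as nn

# Hypothetical stand-ins for the relevant ChatGLMConfig fields.
padded_vocab_size = 65024
hidden_size = 4096
torch_dtype = torch.float16
device = "cpu"

# Same shape of construction as the new self.embed_tokens: a plain nn.Embedding
# instead of the previous init_method(Embedding, config, **init_kwargs) wrapper.
embed_tokens = nn.Embedding(
    padded_vocab_size,
    hidden_size,
    dtype=torch_dtype,
    device=device,
)

input_ids = torch.tensor([[64790, 64792, 30910]])  # [batch, seq_len] token ids (example values)
inputs_embeds = embed_tokens(input_ids)            # -> shape [1, 3, hidden_size], dtype float16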
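
And a minimal sketch of the final-norm call path after the rename from final_layernorm to norm. The RMSNorm class here is an assumption based on the usual root-mean-square formulation; in the real file the module comes from the repository's own RMSNorm (or LayerNorm when config.rmsnorm is false), and the config values below are stand-ins.

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    # Assumed RMSNorm: rescale by the inverse root-mean-square over the hidden dim.
    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size, device=device, dtype=dtype))
        self.eps = eps

    def forward(self, hidden_states):
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        return (self.weight * hidden_states).to(self.weight.dtype)

# Hypothetical config stand-ins.
hidden_size, layernorm_epsilon, use_rmsnorm, post_layer_norm = 4096, 1e-5, True, True

LayerNormFunc = RMSNorm if use_rmsnorm else nn.LayerNorm
norm = LayerNormFunc(hidden_size, eps=layernorm_epsilon)

hidden_states = torch.randn(1, 3, hidden_size)
if post_layer_norm:
    hidden_states = norm(hidden_states)  # same call site as the renamed self.norm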