ybelkada committed
Commit 3e08fb0 · 1 Parent(s): ad87787

Update modeling_chatglm.py

Files changed (1)
  1. modeling_chatglm.py +68 -16
modeling_chatglm.py CHANGED
@@ -753,9 +753,20 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
         self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
                                               dtype=config.torch_dtype)
-        self.encoder = init_method(GLMTransformer, config, **init_kwargs)
-        self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
-                                        dtype=config.torch_dtype, **init_kwargs)
+
+        # Transformer layers.
+        def build_layer(layer_number):
+            return GLMBlock(config, layer_number, device=device)
+
+        self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
+        self.num_layers = config.num_layers
+
+        if self.post_layer_norm:
+            LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
+            # Final layer norm before output.
+            self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
+                                                 dtype=config.torch_dtype)
+
         self.pre_seq_len = config.pre_seq_len
         self.prefix_projection = config.prefix_projection
         if self.pre_seq_len is not None:
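
The hunk above drops the GLMTransformer/encoder indirection: ChatGLMModel now builds its GLMBlock stack directly into a torch.nn.ModuleList and, when post_layer_norm is set, adds a final RMSNorm or LayerNorm of its own. A minimal runnable sketch of that construction pattern follows; DummyBlock, TinyStack and the default sizes are placeholders of mine, not code from the commit (note the sketch assigns num_layers before the list comprehension that reads it).

from torch import nn


class DummyBlock(nn.Module):
    """Stand-in for GLMBlock: one transformer layer, numbered from 1."""

    def __init__(self, hidden_size: int, layer_number: int):
        super().__init__()
        self.layer_number = layer_number
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states):
        return hidden_states + self.proj(hidden_states)


class TinyStack(nn.Module):
    """Builds the layer list inline, the way the new ChatGLMModel.__init__ does."""

    def __init__(self, hidden_size: int = 64, num_layers: int = 4, post_layer_norm: bool = True):
        super().__init__()
        self.num_layers = num_layers  # set before the comprehension below reads it
        self.layers = nn.ModuleList(
            [DummyBlock(hidden_size, i + 1) for i in range(self.num_layers)]
        )
        self.post_layer_norm = post_layer_norm
        if self.post_layer_norm:
            # Final norm applied once, after the last block.
            self.final_layernorm = nn.LayerNorm(hidden_size)


stack = TinyStack()
print(len(stack.layers), hasattr(stack, "final_layernorm"))  # 4 True
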
@@ -827,10 +838,50 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
 
         # Run encoder.
-        hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
-            inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
-            kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
-        )
+        if not past_key_values:
+            past_key_values = [None for _ in range(self.num_layers)]
+        presents = () if use_cache else None
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        all_self_attentions = None
+        all_hidden_states = () if output_hidden_states else None
+        for index in range(self.num_layers):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            layer = self._get_layer(index)
+            if self.gradient_checkpointing and self.training:
+                layer_ret = torch.utils.checkpoint.checkpoint(
+                    layer,
+                    hidden_states,
+                    attention_mask,
+                    rotary_pos_emb,
+                    past_key_values[index],
+                    use_cache
+                )
+            else:
+                layer_ret = layer(
+                    hidden_states,
+                    attention_mask,
+                    rotary_pos_emb,
+                    kv_cache=past_key_values[index],
+                    use_cache=use_cache
+                )
+            hidden_states, kv_cache = layer_ret
+            if use_cache:
+                presents = presents + (kv_cache,)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        # Final layer norm.
+        if self.post_layer_norm:
+            hidden_states = self.final_layernorm(hidden_states)
 
         if not return_dict:
             return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
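
The removed call delegated the whole pass to GLMTransformer.forward; the added block inlines that loop into ChatGLMModel.forward: seed one kv-cache slot per layer, walk the layers (optionally through torch.utils.checkpoint.checkpoint while training), collect presents and hidden states, then apply the final layer norm. Below is a hedged, self-contained sketch of the same loop shape; EchoLayer, run_layers and the hidden_states = inputs_embeds seed are my assumptions for illustration (the gradient-checkpointing branch is omitted), not lines from the commit.

import torch
from torch import nn


class EchoLayer(nn.Module):
    """Toy layer following the (hidden_states, kv_cache) return convention above."""

    def __init__(self, hidden_size):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, kv_cache=None, use_cache=True):
        hidden_states = hidden_states + self.proj(hidden_states)
        new_cache = hidden_states.detach() if use_cache else None
        return hidden_states, new_cache


def run_layers(layers, inputs_embeds, past_key_values=None, use_cache=True,
               output_hidden_states=False):
    # Seed one cache slot per layer when no past is given.
    if not past_key_values:
        past_key_values = [None for _ in range(len(layers))]
    presents = () if use_cache else None
    all_hidden_states = () if output_hidden_states else None

    hidden_states = inputs_embeds  # assumed starting point
    for index, layer in enumerate(layers):
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        hidden_states, kv_cache = layer(
            hidden_states, kv_cache=past_key_values[index], use_cache=use_cache
        )
        if use_cache:
            presents = presents + (kv_cache,)
    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)
    return hidden_states, presents, all_hidden_states


layers = nn.ModuleList([EchoLayer(8) for _ in range(3)])
out, presents, _ = run_layers(layers, torch.randn(2, 4, 8))
print(out.shape, len(presents))  # torch.Size([2, 4, 8]) 3
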
@@ -844,7 +895,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
     def quantize(self, weight_bit_width: int):
         from .quantization import quantize
-        quantize(self.encoder, weight_bit_width)
+        quantize(self, weight_bit_width)
         return self
 
 
@@ -853,7 +904,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         super().__init__(config)
 
         self.max_sequence_length = config.max_length
-        self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
+        self.model = ChatGLMModel(config, empty_init=empty_init, device=device)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.config = config
         self.quantized = False
 
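
This hunk renames the backbone attribute from self.transformer to self.model and introduces a standalone lm_head on the wrapper, replacing the output_layer that previously lived inside ChatGLMModel (note the head is sized with config.vocab_size where the old output_layer used config.padded_vocab_size). A later hunk routes logits through it; the toy wiring below shows the shape flow, with TinyBackbone, TinyForCausalLM and the sizes being placeholders of mine rather than code from the commit.

import torch
from torch import nn


class TinyBackbone(nn.Module):
    """Placeholder for ChatGLMModel: returns hidden states shaped (seq, batch, hidden)."""

    def __init__(self, vocab_size=100, hidden_size=32):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)

    def forward(self, input_ids):
        # ChatGLM keeps a sequence-first layout, hence the transposes.
        return self.embedding(input_ids).transpose(0, 1)


class TinyForCausalLM(nn.Module):
    def __init__(self, vocab_size=100, hidden_size=32):
        super().__init__()
        self.model = TinyBackbone(vocab_size, hidden_size)             # was self.transformer
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)  # was model.output_layer

    def forward(self, input_ids, return_last_logit=False):
        hidden_states = self.model(input_ids)          # (seq, batch, hidden)
        if return_last_logit:
            hidden_states = hidden_states[-1:]         # keep only the last position
        lm_logits = self.lm_head(hidden_states)        # (seq, batch, vocab)
        return lm_logits.transpose(0, 1).contiguous()  # back to (batch, seq, vocab)


logits = TinyForCausalLM()(torch.randint(0, 100, (2, 5)))
print(logits.shape)  # torch.Size([2, 5, 100])
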
@@ -934,7 +986,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        transformer_outputs = self.transformer(
+        transformer_outputs = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
@@ -948,7 +1000,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         hidden_states = transformer_outputs[0]
         if return_last_logit:
             hidden_states = hidden_states[-1:]
-        lm_logits = self.transformer.output_layer(hidden_states)
+        lm_logits = self.lm_head(hidden_states)
         lm_logits = lm_logits.transpose(0, 1).contiguous()
 
         loss = None
@@ -1062,8 +1114,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
            inputs = inputs.to(self.device)
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[0]
-           if self.transformer.pre_seq_len is not None:
-               past_length -= self.transformer.pre_seq_len
+           if self.model.pre_seq_len is not None:
+               past_length -= self.model.pre_seq_len
            inputs.position_ids += past_length
            attention_mask = inputs.attention_mask
            attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
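
With P-tuning v2 prefixes (pre_seq_len set), the cached key/values carry pre_seq_len extra positions that do not correspond to generated tokens, so the position-id offset and the attention-mask padding subtract them; the hunk above only changes which attribute exposes pre_seq_len (self.model instead of self.transformer). A small sketch of that bookkeeping on plain tensors, with illustrative values:

import torch

# Assumed shapes: the cache is sequence-first, so dim 0 is its length.
pre_seq_len = 4            # learned prefix length (None when prefix tuning is off)
cached_len = 10            # past_key_values[0][0].shape[0]
attention_mask = torch.ones(1, 3, dtype=torch.long)  # mask for 3 new tokens

past_length = cached_len
if pre_seq_len is not None:
    past_length -= pre_seq_len  # do not count prefix slots as real tokens

position_ids = torch.arange(3) + past_length  # continue numbering after real tokens
attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)

print(position_ids.tolist())  # [6, 7, 8]
print(attention_mask.shape)   # torch.Size([1, 9])
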
@@ -1205,7 +1257,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 
         self.config.quantization_bit = bits
 
-        self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device,
+        self.model = quantize(self.model, bits, empty_init=empty_init, device=device,
                               **kwargs)
         return self
 
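
Both quantization hooks now operate on whole modules: ChatGLMModel.quantize passes self (there is no encoder sub-module left to target), and the wrapper above reassigns self.model to whatever quantize returns. The repo's .quantization helper is not shown in this diff; as a stand-in, here is a generic sketch using PyTorch's built-in dynamic int8 quantization in the same pass-the-module, get-a-module-back style (an assumption of mine, not the commit's kernel):

import torch
from torch import nn

# Toy module standing in for the model being quantized.
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))

# Replace every nn.Linear with a dynamically quantized int8 version, then
# reassign the result, mirroring `self.model = quantize(self.model, ...)`.
model = torch.ao.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

out = model(torch.randn(2, 16))
print(out.shape)  # torch.Size([2, 4])
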
@@ -1215,7 +1267,7 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
         super().__init__(config)
 
         self.num_labels = config.num_labels
-        self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
+        self.model = ChatGLMModel(config, empty_init=empty_init, device=device)
 
         self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half)
         if config.classifier_dropout is not None:
@@ -1242,7 +1294,7 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
    ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-       transformer_outputs = self.transformer(
+       transformer_outputs = self.model(
           input_ids=input_ids,
           position_ids=position_ids,
           attention_mask=attention_mask,
 