Update modeling_chatglm.py
modeling_chatglm.py  CHANGED  (+68 -16)
@@ -753,9 +753,20 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
         self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
                                               dtype=config.torch_dtype)
-
-
-
+
+        # Transformer layers.
+        def build_layer(layer_number):
+            return GLMBlock(config, layer_number, device=device)
+
+        self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
+        self.num_layers = config.num_layers
+
+        if self.post_layer_norm:
+            LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
+            # Final layer norm before output.
+            self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
+                                                 dtype=config.torch_dtype)
+
         self.pre_seq_len = config.pre_seq_len
         self.prefix_projection = config.prefix_projection
         if self.pre_seq_len is not None:
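
The added block builds the transformer stack directly on ChatGLMModel as a ModuleList of GLMBlock instances, with a closure assigning 1-based layer numbers. A minimal, self-contained sketch of that construction pattern (ToyBlock is a hypothetical stand-in for GLMBlock, not code from this commit):

from torch import nn

# ToyBlock only illustrates the build_layer pattern: each block receives a
# 1-based layer_number; sizes here are made up.
class ToyBlock(nn.Module):
    def __init__(self, hidden_size, layer_number):
        super().__init__()
        self.layer_number = layer_number
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states):
        return self.proj(hidden_states)

num_layers, hidden_size = 4, 8

def build_layer(layer_number):
    return ToyBlock(hidden_size, layer_number)

layers = nn.ModuleList([build_layer(i + 1) for i in range(num_layers)])
assert [block.layer_number for block in layers] == [1, 2, 3, 4]
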
@@ -827,10 +838,50 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
 
         # Run encoder.
-
-
-
-
+        if not past_key_values:
+            past_key_values = [None for _ in range(self.num_layers)]
+        presents = () if use_cache else None
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        all_self_attentions = None
+        all_hidden_states = () if output_hidden_states else None
+        for index in range(self.num_layers):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            layer = self._get_layer(index)
+            if self.gradient_checkpointing and self.training:
+                layer_ret = torch.utils.checkpoint.checkpoint(
+                    layer,
+                    hidden_states,
+                    attention_mask,
+                    rotary_pos_emb,
+                    past_key_values[index],
+                    use_cache
+                )
+            else:
+                layer_ret = layer(
+                    hidden_states,
+                    attention_mask,
+                    rotary_pos_emb,
+                    kv_cache=past_key_values[index],
+                    use_cache=use_cache
+                )
+            hidden_states, kv_cache = layer_ret
+            if use_cache:
+                presents = presents + (kv_cache,)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        # Final layer norm.
+        if self.post_layer_norm:
+            hidden_states = self.final_layernorm(hidden_states)
 
         if not return_dict:
             return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
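
The new decode loop follows the usual key/value-cache pattern: on the first forward pass past_key_values is a list of None placeholders, each layer returns its updated cache, and the caches are collected into presents when use_cache is enabled, to be fed back as past_key_values on the next step. A self-contained toy sketch of that pattern (DummyLayer is a stand-in for GLMBlock and is not from this commit):

import torch
from torch import nn

# DummyLayer mimics the (hidden_states, kv_cache) return shape: it appends the
# current hidden states to its cache and passes the hidden states through.
class DummyLayer(nn.Module):
    def forward(self, hidden_states, kv_cache=None, use_cache=True):
        new_cache = hidden_states if kv_cache is None else torch.cat([kv_cache, hidden_states], dim=0)
        return hidden_states, (new_cache if use_cache else None)

num_layers = 3
layers = nn.ModuleList(DummyLayer() for _ in range(num_layers))
hidden_states = torch.zeros(1, 2, 4)                 # [seq, batch, hidden]
past_key_values = [None for _ in range(num_layers)]  # first step: no cache yet
presents = ()
for index in range(num_layers):
    hidden_states, kv_cache = layers[index](hidden_states, kv_cache=past_key_values[index])
    presents = presents + (kv_cache,)                # one cache entry per layer
# On the next step, `presents` would be passed back in as `past_key_values`.
assert len(presents) == num_layers
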
@@ -844,7 +895,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
     def quantize(self, weight_bit_width: int):
         from .quantization import quantize
-        quantize(self
+        quantize(self, weight_bit_width)
         return self
 
 
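
The quantize() method forwards the requested bit width to the repo's own .quantization module and returns self. As a hedged usage sketch (assuming `model` is an instance of this class, e.g. loaded with trust_remote_code=True), 4-bit weight quantization would look like:

# `model` is assumed to be a model instance built from this repository;
# quantize() dispatches to the repo's .quantization module and returns self.
model = model.quantize(4)  # weight_bit_width = 4
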
@@ -853,7 +904,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         super().__init__(config)
 
         self.max_sequence_length = config.max_length
-        self.
+        self.model = ChatGLMModel(config, empty_init=empty_init, device=device)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.config = config
         self.quantized = False
 
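
ChatGLMForConditionalGeneration now holds the backbone as `self.model` and an explicit nn.Linear head as `self.lm_head`. Since submodule attribute names become state_dict key prefixes in PyTorch, this naming determines the checkpoint key layout. A toy, self-contained illustration (Wrapper is not ChatGLM; the sizes are made up):

from torch import nn

# Attribute names chosen in __init__ become state_dict prefixes, so the layout
# above yields "model.*" and "lm_head.weight" keys.
class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Linear(4, 4)                 # stands in for ChatGLMModel
        self.lm_head = nn.Linear(4, 10, bias=False)

wrapper = Wrapper()
print(sorted(wrapper.state_dict().keys()))
# ['lm_head.weight', 'model.bias', 'model.weight']
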
@@ -934,7 +986,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        transformer_outputs = self.
+        transformer_outputs = self.model(
             input_ids=input_ids,
             position_ids=position_ids,
             attention_mask=attention_mask,
@@ -948,7 +1000,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         hidden_states = transformer_outputs[0]
         if return_last_logit:
             hidden_states = hidden_states[-1:]
-        lm_logits = self.
+        lm_logits = self.lm_head(hidden_states)
         lm_logits = lm_logits.transpose(0, 1).contiguous()
 
         loss = None
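
The logits are now produced by the explicit lm_head and then transposed from the backbone's [seq, batch, hidden] layout to the conventional [batch, seq, vocab], as the transpose(0, 1) in the hunk shows. A shape-only sketch with made-up sizes (the real hidden_size and vocab_size come from ChatGLMConfig):

import torch
from torch import nn

hidden_size, vocab_size, seq_len, batch = 8, 32, 5, 2
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
hidden_states = torch.randn(seq_len, batch, hidden_size)  # backbone output layout
lm_logits = lm_head(hidden_states)                        # [seq, batch, vocab]
lm_logits = lm_logits.transpose(0, 1).contiguous()        # [batch, seq, vocab]
assert lm_logits.shape == (batch, seq_len, vocab_size)
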
@@ -1062,8 +1114,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         inputs = inputs.to(self.device)
         if past_key_values is not None:
             past_length = past_key_values[0][0].shape[0]
-            if self.
-                past_length -= self.
+            if self.model.pre_seq_len is not None:
+                past_length -= self.model.pre_seq_len
             inputs.position_ids += past_length
             attention_mask = inputs.attention_mask
             attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
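
When prefix tuning is active (pre_seq_len is not None), the cached sequence length reported by past_key_values includes the learned prefix tokens, so the code subtracts pre_seq_len before shifting position_ids and padding the attention mask. A small numeric sketch (the 64 and 10 are made-up values, not from the commit):

# Made-up numbers: a 64-token learned prefix plus 10 real tokens already cached.
pre_seq_len = 64
cached_len = 64 + 10                     # what past_key_values[0][0].shape[0] would report
past_length = cached_len - pre_seq_len   # only the 10 real tokens shift the positions
assert past_length == 10
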
@@ -1205,7 +1257,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 
         self.config.quantization_bit = bits
 
-        self.
+        self.model = quantize(self.model, bits, empty_init=empty_init, device=device,
                               **kwargs)
         return self
 
@@ -1215,7 +1267,7 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
         super().__init__(config)
 
         self.num_labels = config.num_labels
-        self.
+        self.model = ChatGLMModel(config, empty_init=empty_init, device=device)
 
         self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half)
         if config.classifier_dropout is not None:
@@ -1242,7 +1294,7 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
     ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        transformer_outputs = self.
+        transformer_outputs = self.model(
             input_ids=input_ids,
             position_ids=position_ids,
             attention_mask=attention_mask,