davidlvxin commited on
Commit
7c20485
1 Parent(s): 6b89d6b

Upload folder using huggingface_hub

Browse files
.mdl ADDED
Binary file (47 Bytes). View file
 
.msc ADDED
Binary file (1.26 kB). View file
 
.mv ADDED
@@ -0,0 +1 @@
 
 
1
+ Revision:master,CreatedAt:1725245632
config.json ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "THUDM/glm-4-9b-chat",
3
+ "model_type": "chatglm",
4
+ "architectures": [
5
+ "ChatGLMModel"
6
+ ],
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_chatglm.ChatGLMConfig",
9
+ "AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration",
10
+ "AutoModelForCausalLM": "modeling_chatglm.ChatGLMForConditionalGeneration",
11
+ "AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration",
12
+ "AutoModelForSequenceClassification": "modeling_chatglm.ChatGLMForSequenceClassification"
13
+ },
14
+ "add_bias_linear": false,
15
+ "add_qkv_bias": true,
16
+ "apply_query_key_layer_scaling": true,
17
+ "apply_residual_connection_post_layernorm": false,
18
+ "attention_dropout": 0.0,
19
+ "attention_softmax_in_fp32": true,
20
+ "attn_implementation": "sdpa",
21
+ "bias_dropout_fusion": true,
22
+ "ffn_hidden_size": 13696,
23
+ "fp32_residual_connection": false,
24
+ "hidden_dropout": 0.0,
25
+ "hidden_size": 4096,
26
+ "kv_channels": 128,
27
+ "layernorm_epsilon": 1e-5,
28
+ "multi_query_attention": true,
29
+ "multi_query_group_num": 2,
30
+ "num_attention_heads": 32,
31
+ "num_hidden_layers": 40,
32
+ "num_layers": 40,
33
+ "rope_ratio": 500,
34
+ "original_rope": true,
35
+ "padded_vocab_size": 151552,
36
+ "post_layer_norm": true,
37
+ "rmsnorm": true,
38
+ "seq_length": 131072,
39
+ "use_cache": true,
40
+ "torch_dtype": "bfloat16",
41
+ "transformers_version": "4.43.0",
42
+ "tie_word_embeddings": false,
43
+ "eos_token_id": [151329, 151336, 151338],
44
+ "pad_token_id": 151329
45
+ }
configuration.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"framework":"Pytorch","task":"nli"}
configuration_chatglm.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
+ class ChatGLMConfig(PretrainedConfig):
5
+ model_type = "chatglm"
6
+
7
+ def __init__(
8
+ self,
9
+ num_layers=28,
10
+ padded_vocab_size=65024,
11
+ hidden_size=4096,
12
+ ffn_hidden_size=13696,
13
+ kv_channels=128,
14
+ num_attention_heads=32,
15
+ seq_length=2048,
16
+ hidden_dropout=0.0,
17
+ classifier_dropout=None,
18
+ attention_dropout=0.0,
19
+ layernorm_epsilon=1e-5,
20
+ rmsnorm=True,
21
+ apply_residual_connection_post_layernorm=False,
22
+ post_layer_norm=True,
23
+ add_bias_linear=False,
24
+ add_qkv_bias=False,
25
+ bias_dropout_fusion=True,
26
+ multi_query_attention=False,
27
+ multi_query_group_num=1,
28
+ rope_ratio=1,
29
+ apply_query_key_layer_scaling=True,
30
+ attention_softmax_in_fp32=True,
31
+ fp32_residual_connection=False,
32
+ **kwargs
33
+ ):
34
+ self.num_layers = num_layers
35
+ self.vocab_size = padded_vocab_size
36
+ self.padded_vocab_size = padded_vocab_size
37
+ self.hidden_size = hidden_size
38
+ self.ffn_hidden_size = ffn_hidden_size
39
+ self.kv_channels = kv_channels
40
+ self.num_attention_heads = num_attention_heads
41
+ self.seq_length = seq_length
42
+ self.hidden_dropout = hidden_dropout
43
+ self.classifier_dropout = classifier_dropout
44
+ self.attention_dropout = attention_dropout
45
+ self.layernorm_epsilon = layernorm_epsilon
46
+ self.rmsnorm = rmsnorm
47
+ self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
48
+ self.post_layer_norm = post_layer_norm
49
+ self.add_bias_linear = add_bias_linear
50
+ self.add_qkv_bias = add_qkv_bias
51
+ self.bias_dropout_fusion = bias_dropout_fusion
52
+ self.multi_query_attention = multi_query_attention
53
+ self.multi_query_group_num = multi_query_group_num
54
+ self.rope_ratio = rope_ratio
55
+ self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
56
+ self.attention_softmax_in_fp32 = attention_softmax_in_fp32
57
+ self.fp32_residual_connection = fp32_residual_connection
58
+ super().__init__(**kwargs)
generation_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "eos_token_id": [
3
+ 151329,
4
+ 151336,
5
+ 151338
6
+ ],
7
+ "pad_token_id": 151329,
8
+ "do_sample": true,
9
+ "temperature": 0.8,
10
+ "max_length": 128000,
11
+ "top_p": 0.8,
12
+ "transformers_version": "4.40.2"
13
+ }
model-00000-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5bd83129dcc5920b2787a4118a40e94b347631bdda5635f3787d575e962a687e
3
+ size 4640215520
model-00001-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6768f3a6029e72732402a9d1ae10e825d596d52cf6f74577c1d9d37d357dcad7
3
+ size 4640215600
model-00002-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:90299a1c12569852ecc949bbe758d5426c0e46f8ae112824a2defdba62fb8d72
3
+ size 4640215600
model-00003-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de090bd4be3dbc11ad532b913d186fc51ccb3a1cb0b6ac7f3948fd527df03b2c
3
+ size 4640215600
model-00004-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b0fff88561f62529ada17b57edcf13a859ca45d24441364b5de4005815046d85
3
+ size 2483036544
model.safetensors.index.json ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 21043855360
4
+ },
5
+ "weight_map": {
6
+ "transformer.embedding.word_embeddings.weight": "model-00004-of-00005.safetensors",
7
+ "transformer.output_layer.weight": "model-00004-of-00005.safetensors",
8
+ "transformer.encoder.final_layernorm.weight": "model-00004-of-00005.safetensors",
9
+ "transformer.encoder.layers.30.input_layernorm.weight": "model-00003-of-00005.safetensors",
10
+ "transformer.encoder.layers.30.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
11
+ "transformer.encoder.layers.30.self_attention.query_key_value.weight": "model-00003-of-00005.safetensors",
12
+ "transformer.encoder.layers.30.self_attention.query_key_value.bias": "model-00003-of-00005.safetensors",
13
+ "transformer.encoder.layers.30.self_attention.dense.weight": "model-00003-of-00005.safetensors",
14
+ "transformer.encoder.layers.30.mlp.dense_h_to_4h.weight": "model-00003-of-00005.safetensors",
15
+ "transformer.encoder.layers.30.mlp.dense_4h_to_h.weight": "model-00003-of-00005.safetensors",
16
+ "transformer.encoder.layers.30.mlp.dense_4h_to_h.bias": "model-00003-of-00005.safetensors",
17
+ "transformer.encoder.layers.3.input_layernorm.weight": "model-00000-of-00005.safetensors",
18
+ "transformer.encoder.layers.3.post_attention_layernorm.weight": "model-00000-of-00005.safetensors",
19
+ "transformer.encoder.layers.3.self_attention.query_key_value.weight": "model-00000-of-00005.safetensors",
20
+ "transformer.encoder.layers.3.self_attention.query_key_value.bias": "model-00000-of-00005.safetensors",
21
+ "transformer.encoder.layers.3.self_attention.dense.weight": "model-00000-of-00005.safetensors",
22
+ "transformer.encoder.layers.3.mlp.dense_h_to_4h.weight": "model-00000-of-00005.safetensors",
23
+ "transformer.encoder.layers.3.mlp.dense_4h_to_h.weight": "model-00000-of-00005.safetensors",
24
+ "transformer.encoder.layers.3.mlp.dense_4h_to_h.bias": "model-00000-of-00005.safetensors",
25
+ "transformer.encoder.layers.0.input_layernorm.weight": "model-00000-of-00005.safetensors",
26
+ "transformer.encoder.layers.0.post_attention_layernorm.weight": "model-00000-of-00005.safetensors",
27
+ "transformer.encoder.layers.0.self_attention.query_key_value.weight": "model-00000-of-00005.safetensors",
28
+ "transformer.encoder.layers.0.self_attention.query_key_value.bias": "model-00000-of-00005.safetensors",
29
+ "transformer.encoder.layers.0.self_attention.dense.weight": "model-00000-of-00005.safetensors",
30
+ "transformer.encoder.layers.0.mlp.dense_h_to_4h.weight": "model-00000-of-00005.safetensors",
31
+ "transformer.encoder.layers.0.mlp.dense_4h_to_h.weight": "model-00000-of-00005.safetensors",
32
+ "transformer.encoder.layers.0.mlp.dense_4h_to_h.bias": "model-00000-of-00005.safetensors",
33
+ "transformer.encoder.layers.27.input_layernorm.weight": "model-00002-of-00005.safetensors",
34
+ "transformer.encoder.layers.27.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
35
+ "transformer.encoder.layers.27.self_attention.query_key_value.weight": "model-00002-of-00005.safetensors",
36
+ "transformer.encoder.layers.27.self_attention.query_key_value.bias": "model-00002-of-00005.safetensors",
37
+ "transformer.encoder.layers.27.self_attention.dense.weight": "model-00002-of-00005.safetensors",
38
+ "transformer.encoder.layers.27.mlp.dense_h_to_4h.weight": "model-00002-of-00005.safetensors",
39
+ "transformer.encoder.layers.27.mlp.dense_4h_to_h.weight": "model-00002-of-00005.safetensors",
40
+ "transformer.encoder.layers.27.mlp.dense_4h_to_h.bias": "model-00002-of-00005.safetensors",
41
+ "transformer.encoder.layers.9.input_layernorm.weight": "model-00000-of-00005.safetensors",
42
+ "transformer.encoder.layers.9.post_attention_layernorm.weight": "model-00000-of-00005.safetensors",
43
+ "transformer.encoder.layers.9.self_attention.query_key_value.weight": "model-00000-of-00005.safetensors",
44
+ "transformer.encoder.layers.9.self_attention.query_key_value.bias": "model-00000-of-00005.safetensors",
45
+ "transformer.encoder.layers.9.self_attention.dense.weight": "model-00000-of-00005.safetensors",
46
+ "transformer.encoder.layers.9.mlp.dense_h_to_4h.weight": "model-00000-of-00005.safetensors",
47
+ "transformer.encoder.layers.9.mlp.dense_4h_to_h.weight": "model-00000-of-00005.safetensors",
48
+ "transformer.encoder.layers.9.mlp.dense_4h_to_h.bias": "model-00000-of-00005.safetensors",
49
+ "transformer.encoder.layers.17.input_layernorm.weight": "model-00001-of-00005.safetensors",
50
+ "transformer.encoder.layers.17.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
51
+ "transformer.encoder.layers.17.self_attention.query_key_value.weight": "model-00001-of-00005.safetensors",
52
+ "transformer.encoder.layers.17.self_attention.query_key_value.bias": "model-00001-of-00005.safetensors",
53
+ "transformer.encoder.layers.17.self_attention.dense.weight": "model-00001-of-00005.safetensors",
54
+ "transformer.encoder.layers.17.mlp.dense_h_to_4h.weight": "model-00001-of-00005.safetensors",
55
+ "transformer.encoder.layers.17.mlp.dense_4h_to_h.weight": "model-00001-of-00005.safetensors",
56
+ "transformer.encoder.layers.17.mlp.dense_4h_to_h.bias": "model-00001-of-00005.safetensors",
57
+ "transformer.encoder.layers.34.input_layernorm.weight": "model-00003-of-00005.safetensors",
58
+ "transformer.encoder.layers.34.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
59
+ "transformer.encoder.layers.34.self_attention.query_key_value.weight": "model-00003-of-00005.safetensors",
60
+ "transformer.encoder.layers.34.self_attention.query_key_value.bias": "model-00003-of-00005.safetensors",
61
+ "transformer.encoder.layers.34.self_attention.dense.weight": "model-00003-of-00005.safetensors",
62
+ "transformer.encoder.layers.34.mlp.dense_h_to_4h.weight": "model-00003-of-00005.safetensors",
63
+ "transformer.encoder.layers.34.mlp.dense_4h_to_h.weight": "model-00003-of-00005.safetensors",
64
+ "transformer.encoder.layers.34.mlp.dense_4h_to_h.bias": "model-00003-of-00005.safetensors",
65
+ "transformer.encoder.layers.26.input_layernorm.weight": "model-00002-of-00005.safetensors",
66
+ "transformer.encoder.layers.26.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
67
+ "transformer.encoder.layers.26.self_attention.query_key_value.weight": "model-00002-of-00005.safetensors",
68
+ "transformer.encoder.layers.26.self_attention.query_key_value.bias": "model-00002-of-00005.safetensors",
69
+ "transformer.encoder.layers.26.self_attention.dense.weight": "model-00002-of-00005.safetensors",
70
+ "transformer.encoder.layers.26.mlp.dense_h_to_4h.weight": "model-00002-of-00005.safetensors",
71
+ "transformer.encoder.layers.26.mlp.dense_4h_to_h.weight": "model-00002-of-00005.safetensors",
72
+ "transformer.encoder.layers.26.mlp.dense_4h_to_h.bias": "model-00002-of-00005.safetensors",
73
+ "transformer.encoder.layers.6.input_layernorm.weight": "model-00000-of-00005.safetensors",
74
+ "transformer.encoder.layers.6.post_attention_layernorm.weight": "model-00000-of-00005.safetensors",
75
+ "transformer.encoder.layers.6.self_attention.query_key_value.weight": "model-00000-of-00005.safetensors",
76
+ "transformer.encoder.layers.6.self_attention.query_key_value.bias": "model-00000-of-00005.safetensors",
77
+ "transformer.encoder.layers.6.self_attention.dense.weight": "model-00000-of-00005.safetensors",
78
+ "transformer.encoder.layers.6.mlp.dense_h_to_4h.weight": "model-00000-of-00005.safetensors",
79
+ "transformer.encoder.layers.6.mlp.dense_4h_to_h.weight": "model-00000-of-00005.safetensors",
80
+ "transformer.encoder.layers.6.mlp.dense_4h_to_h.bias": "model-00000-of-00005.safetensors",
81
+ "transformer.encoder.layers.20.input_layernorm.weight": "model-00002-of-00005.safetensors",
82
+ "transformer.encoder.layers.20.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
83
+ "transformer.encoder.layers.20.self_attention.query_key_value.weight": "model-00002-of-00005.safetensors",
84
+ "transformer.encoder.layers.20.self_attention.query_key_value.bias": "model-00002-of-00005.safetensors",
85
+ "transformer.encoder.layers.20.self_attention.dense.weight": "model-00002-of-00005.safetensors",
86
+ "transformer.encoder.layers.20.mlp.dense_h_to_4h.weight": "model-00002-of-00005.safetensors",
87
+ "transformer.encoder.layers.20.mlp.dense_4h_to_h.weight": "model-00002-of-00005.safetensors",
88
+ "transformer.encoder.layers.20.mlp.dense_4h_to_h.bias": "model-00002-of-00005.safetensors",
89
+ "transformer.encoder.layers.33.input_layernorm.weight": "model-00003-of-00005.safetensors",
90
+ "transformer.encoder.layers.33.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
91
+ "transformer.encoder.layers.33.self_attention.query_key_value.weight": "model-00003-of-00005.safetensors",
92
+ "transformer.encoder.layers.33.self_attention.query_key_value.bias": "model-00003-of-00005.safetensors",
93
+ "transformer.encoder.layers.33.self_attention.dense.weight": "model-00003-of-00005.safetensors",
94
+ "transformer.encoder.layers.33.mlp.dense_h_to_4h.weight": "model-00003-of-00005.safetensors",
95
+ "transformer.encoder.layers.33.mlp.dense_4h_to_h.weight": "model-00003-of-00005.safetensors",
96
+ "transformer.encoder.layers.33.mlp.dense_4h_to_h.bias": "model-00003-of-00005.safetensors",
97
+ "transformer.encoder.layers.22.input_layernorm.weight": "model-00002-of-00005.safetensors",
98
+ "transformer.encoder.layers.22.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
99
+ "transformer.encoder.layers.22.self_attention.query_key_value.weight": "model-00002-of-00005.safetensors",
100
+ "transformer.encoder.layers.22.self_attention.query_key_value.bias": "model-00002-of-00005.safetensors",
101
+ "transformer.encoder.layers.22.self_attention.dense.weight": "model-00002-of-00005.safetensors",
102
+ "transformer.encoder.layers.22.mlp.dense_h_to_4h.weight": "model-00002-of-00005.safetensors",
103
+ "transformer.encoder.layers.22.mlp.dense_4h_to_h.weight": "model-00002-of-00005.safetensors",
104
+ "transformer.encoder.layers.22.mlp.dense_4h_to_h.bias": "model-00002-of-00005.safetensors",
105
+ "transformer.encoder.layers.1.input_layernorm.weight": "model-00000-of-00005.safetensors",
106
+ "transformer.encoder.layers.1.post_attention_layernorm.weight": "model-00000-of-00005.safetensors",
107
+ "transformer.encoder.layers.1.self_attention.query_key_value.weight": "model-00000-of-00005.safetensors",
108
+ "transformer.encoder.layers.1.self_attention.query_key_value.bias": "model-00000-of-00005.safetensors",
109
+ "transformer.encoder.layers.1.self_attention.dense.weight": "model-00000-of-00005.safetensors",
110
+ "transformer.encoder.layers.1.mlp.dense_h_to_4h.weight": "model-00000-of-00005.safetensors",
111
+ "transformer.encoder.layers.1.mlp.dense_4h_to_h.weight": "model-00000-of-00005.safetensors",
112
+ "transformer.encoder.layers.1.mlp.dense_4h_to_h.bias": "model-00000-of-00005.safetensors",
113
+ "transformer.encoder.layers.10.input_layernorm.weight": "model-00001-of-00005.safetensors",
114
+ "transformer.encoder.layers.10.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
115
+ "transformer.encoder.layers.10.self_attention.query_key_value.weight": "model-00001-of-00005.safetensors",
116
+ "transformer.encoder.layers.10.self_attention.query_key_value.bias": "model-00001-of-00005.safetensors",
117
+ "transformer.encoder.layers.10.self_attention.dense.weight": "model-00001-of-00005.safetensors",
118
+ "transformer.encoder.layers.10.mlp.dense_h_to_4h.weight": "model-00001-of-00005.safetensors",
119
+ "transformer.encoder.layers.10.mlp.dense_4h_to_h.weight": "model-00001-of-00005.safetensors",
120
+ "transformer.encoder.layers.10.mlp.dense_4h_to_h.bias": "model-00001-of-00005.safetensors",
121
+ "transformer.encoder.layers.12.input_layernorm.weight": "model-00001-of-00005.safetensors",
122
+ "transformer.encoder.layers.12.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
123
+ "transformer.encoder.layers.12.self_attention.query_key_value.weight": "model-00001-of-00005.safetensors",
124
+ "transformer.encoder.layers.12.self_attention.query_key_value.bias": "model-00001-of-00005.safetensors",
125
+ "transformer.encoder.layers.12.self_attention.dense.weight": "model-00001-of-00005.safetensors",
126
+ "transformer.encoder.layers.12.mlp.dense_h_to_4h.weight": "model-00001-of-00005.safetensors",
127
+ "transformer.encoder.layers.12.mlp.dense_4h_to_h.weight": "model-00001-of-00005.safetensors",
128
+ "transformer.encoder.layers.12.mlp.dense_4h_to_h.bias": "model-00001-of-00005.safetensors",
129
+ "transformer.encoder.layers.36.input_layernorm.weight": "model-00003-of-00005.safetensors",
130
+ "transformer.encoder.layers.36.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
131
+ "transformer.encoder.layers.36.self_attention.query_key_value.weight": "model-00003-of-00005.safetensors",
132
+ "transformer.encoder.layers.36.self_attention.query_key_value.bias": "model-00003-of-00005.safetensors",
133
+ "transformer.encoder.layers.36.self_attention.dense.weight": "model-00003-of-00005.safetensors",
134
+ "transformer.encoder.layers.36.mlp.dense_h_to_4h.weight": "model-00003-of-00005.safetensors",
135
+ "transformer.encoder.layers.36.mlp.dense_4h_to_h.weight": "model-00003-of-00005.safetensors",
136
+ "transformer.encoder.layers.36.mlp.dense_4h_to_h.bias": "model-00003-of-00005.safetensors",
137
+ "transformer.encoder.layers.35.input_layernorm.weight": "model-00003-of-00005.safetensors",
138
+ "transformer.encoder.layers.35.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
139
+ "transformer.encoder.layers.35.self_attention.query_key_value.weight": "model-00003-of-00005.safetensors",
140
+ "transformer.encoder.layers.35.self_attention.query_key_value.bias": "model-00003-of-00005.safetensors",
141
+ "transformer.encoder.layers.35.self_attention.dense.weight": "model-00003-of-00005.safetensors",
142
+ "transformer.encoder.layers.35.mlp.dense_h_to_4h.weight": "model-00003-of-00005.safetensors",
143
+ "transformer.encoder.layers.35.mlp.dense_4h_to_h.weight": "model-00003-of-00005.safetensors",
144
+ "transformer.encoder.layers.35.mlp.dense_4h_to_h.bias": "model-00003-of-00005.safetensors",
145
+ "transformer.encoder.layers.18.input_layernorm.weight": "model-00001-of-00005.safetensors",
146
+ "transformer.encoder.layers.18.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
147
+ "transformer.encoder.layers.18.self_attention.query_key_value.weight": "model-00001-of-00005.safetensors",
148
+ "transformer.encoder.layers.18.self_attention.query_key_value.bias": "model-00001-of-00005.safetensors",
149
+ "transformer.encoder.layers.18.self_attention.dense.weight": "model-00001-of-00005.safetensors",
150
+ "transformer.encoder.layers.18.mlp.dense_h_to_4h.weight": "model-00001-of-00005.safetensors",
151
+ "transformer.encoder.layers.18.mlp.dense_4h_to_h.weight": "model-00001-of-00005.safetensors",
152
+ "transformer.encoder.layers.18.mlp.dense_4h_to_h.bias": "model-00001-of-00005.safetensors",
153
+ "transformer.encoder.layers.28.input_layernorm.weight": "model-00002-of-00005.safetensors",
154
+ "transformer.encoder.layers.28.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
155
+ "transformer.encoder.layers.28.self_attention.query_key_value.weight": "model-00002-of-00005.safetensors",
156
+ "transformer.encoder.layers.28.self_attention.query_key_value.bias": "model-00002-of-00005.safetensors",
157
+ "transformer.encoder.layers.28.self_attention.dense.weight": "model-00002-of-00005.safetensors",
158
+ "transformer.encoder.layers.28.mlp.dense_h_to_4h.weight": "model-00002-of-00005.safetensors",
159
+ "transformer.encoder.layers.28.mlp.dense_4h_to_h.weight": "model-00002-of-00005.safetensors",
160
+ "transformer.encoder.layers.28.mlp.dense_4h_to_h.bias": "model-00002-of-00005.safetensors",
161
+ "transformer.encoder.layers.37.input_layernorm.weight": "model-00003-of-00005.safetensors",
162
+ "transformer.encoder.layers.37.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
163
+ "transformer.encoder.layers.37.self_attention.query_key_value.weight": "model-00003-of-00005.safetensors",
164
+ "transformer.encoder.layers.37.self_attention.query_key_value.bias": "model-00003-of-00005.safetensors",
165
+ "transformer.encoder.layers.37.self_attention.dense.weight": "model-00003-of-00005.safetensors",
166
+ "transformer.encoder.layers.37.mlp.dense_h_to_4h.weight": "model-00003-of-00005.safetensors",
167
+ "transformer.encoder.layers.37.mlp.dense_4h_to_h.weight": "model-00003-of-00005.safetensors",
168
+ "transformer.encoder.layers.37.mlp.dense_4h_to_h.bias": "model-00003-of-00005.safetensors",
169
+ "transformer.encoder.layers.7.input_layernorm.weight": "model-00000-of-00005.safetensors",
170
+ "transformer.encoder.layers.7.post_attention_layernorm.weight": "model-00000-of-00005.safetensors",
171
+ "transformer.encoder.layers.7.self_attention.query_key_value.weight": "model-00000-of-00005.safetensors",
172
+ "transformer.encoder.layers.7.self_attention.query_key_value.bias": "model-00000-of-00005.safetensors",
173
+ "transformer.encoder.layers.7.self_attention.dense.weight": "model-00000-of-00005.safetensors",
174
+ "transformer.encoder.layers.7.mlp.dense_h_to_4h.weight": "model-00000-of-00005.safetensors",
175
+ "transformer.encoder.layers.7.mlp.dense_4h_to_h.weight": "model-00000-of-00005.safetensors",
176
+ "transformer.encoder.layers.7.mlp.dense_4h_to_h.bias": "model-00000-of-00005.safetensors",
177
+ "transformer.encoder.layers.13.input_layernorm.weight": "model-00001-of-00005.safetensors",
178
+ "transformer.encoder.layers.13.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
179
+ "transformer.encoder.layers.13.self_attention.query_key_value.weight": "model-00001-of-00005.safetensors",
180
+ "transformer.encoder.layers.13.self_attention.query_key_value.bias": "model-00001-of-00005.safetensors",
181
+ "transformer.encoder.layers.13.self_attention.dense.weight": "model-00001-of-00005.safetensors",
182
+ "transformer.encoder.layers.13.mlp.dense_h_to_4h.weight": "model-00001-of-00005.safetensors",
183
+ "transformer.encoder.layers.13.mlp.dense_4h_to_h.weight": "model-00001-of-00005.safetensors",
184
+ "transformer.encoder.layers.13.mlp.dense_4h_to_h.bias": "model-00001-of-00005.safetensors",
185
+ "transformer.encoder.layers.29.input_layernorm.weight": "model-00002-of-00005.safetensors",
186
+ "transformer.encoder.layers.29.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
187
+ "transformer.encoder.layers.29.self_attention.query_key_value.weight": "model-00002-of-00005.safetensors",
188
+ "transformer.encoder.layers.29.self_attention.query_key_value.bias": "model-00002-of-00005.safetensors",
189
+ "transformer.encoder.layers.29.self_attention.dense.weight": "model-00002-of-00005.safetensors",
190
+ "transformer.encoder.layers.29.mlp.dense_h_to_4h.weight": "model-00002-of-00005.safetensors",
191
+ "transformer.encoder.layers.29.mlp.dense_4h_to_h.weight": "model-00002-of-00005.safetensors",
192
+ "transformer.encoder.layers.29.mlp.dense_4h_to_h.bias": "model-00002-of-00005.safetensors",
193
+ "transformer.encoder.layers.39.input_layernorm.weight": "model-00003-of-00005.safetensors",
194
+ "transformer.encoder.layers.39.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
195
+ "transformer.encoder.layers.39.self_attention.query_key_value.weight": "model-00003-of-00005.safetensors",
196
+ "transformer.encoder.layers.39.self_attention.query_key_value.bias": "model-00003-of-00005.safetensors",
197
+ "transformer.encoder.layers.39.self_attention.dense.weight": "model-00003-of-00005.safetensors",
198
+ "transformer.encoder.layers.39.mlp.dense_h_to_4h.weight": "model-00003-of-00005.safetensors",
199
+ "transformer.encoder.layers.39.mlp.dense_4h_to_h.weight": "model-00003-of-00005.safetensors",
200
+ "transformer.encoder.layers.39.mlp.dense_4h_to_h.bias": "model-00003-of-00005.safetensors",
201
+ "transformer.encoder.layers.31.input_layernorm.weight": "model-00003-of-00005.safetensors",
202
+ "transformer.encoder.layers.31.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
203
+ "transformer.encoder.layers.31.self_attention.query_key_value.weight": "model-00003-of-00005.safetensors",
204
+ "transformer.encoder.layers.31.self_attention.query_key_value.bias": "model-00003-of-00005.safetensors",
205
+ "transformer.encoder.layers.31.self_attention.dense.weight": "model-00003-of-00005.safetensors",
206
+ "transformer.encoder.layers.31.mlp.dense_h_to_4h.weight": "model-00003-of-00005.safetensors",
207
+ "transformer.encoder.layers.31.mlp.dense_4h_to_h.weight": "model-00003-of-00005.safetensors",
208
+ "transformer.encoder.layers.31.mlp.dense_4h_to_h.bias": "model-00003-of-00005.safetensors",
209
+ "transformer.encoder.layers.32.input_layernorm.weight": "model-00003-of-00005.safetensors",
210
+ "transformer.encoder.layers.32.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
211
+ "transformer.encoder.layers.32.self_attention.query_key_value.weight": "model-00003-of-00005.safetensors",
212
+ "transformer.encoder.layers.32.self_attention.query_key_value.bias": "model-00003-of-00005.safetensors",
213
+ "transformer.encoder.layers.32.self_attention.dense.weight": "model-00003-of-00005.safetensors",
214
+ "transformer.encoder.layers.32.mlp.dense_h_to_4h.weight": "model-00003-of-00005.safetensors",
215
+ "transformer.encoder.layers.32.mlp.dense_4h_to_h.weight": "model-00003-of-00005.safetensors",
216
+ "transformer.encoder.layers.32.mlp.dense_4h_to_h.bias": "model-00003-of-00005.safetensors",
217
+ "transformer.encoder.layers.4.input_layernorm.weight": "model-00000-of-00005.safetensors",
218
+ "transformer.encoder.layers.4.post_attention_layernorm.weight": "model-00000-of-00005.safetensors",
219
+ "transformer.encoder.layers.4.self_attention.query_key_value.weight": "model-00000-of-00005.safetensors",
220
+ "transformer.encoder.layers.4.self_attention.query_key_value.bias": "model-00000-of-00005.safetensors",
221
+ "transformer.encoder.layers.4.self_attention.dense.weight": "model-00000-of-00005.safetensors",
222
+ "transformer.encoder.layers.4.mlp.dense_h_to_4h.weight": "model-00000-of-00005.safetensors",
223
+ "transformer.encoder.layers.4.mlp.dense_4h_to_h.weight": "model-00000-of-00005.safetensors",
224
+ "transformer.encoder.layers.4.mlp.dense_4h_to_h.bias": "model-00000-of-00005.safetensors",
225
+ "transformer.encoder.layers.16.input_layernorm.weight": "model-00001-of-00005.safetensors",
226
+ "transformer.encoder.layers.16.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
227
+ "transformer.encoder.layers.16.self_attention.query_key_value.weight": "model-00001-of-00005.safetensors",
228
+ "transformer.encoder.layers.16.self_attention.query_key_value.bias": "model-00001-of-00005.safetensors",
229
+ "transformer.encoder.layers.16.self_attention.dense.weight": "model-00001-of-00005.safetensors",
230
+ "transformer.encoder.layers.16.mlp.dense_h_to_4h.weight": "model-00001-of-00005.safetensors",
231
+ "transformer.encoder.layers.16.mlp.dense_4h_to_h.weight": "model-00001-of-00005.safetensors",
232
+ "transformer.encoder.layers.16.mlp.dense_4h_to_h.bias": "model-00001-of-00005.safetensors",
233
+ "transformer.encoder.layers.19.input_layernorm.weight": "model-00001-of-00005.safetensors",
234
+ "transformer.encoder.layers.19.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
235
+ "transformer.encoder.layers.19.self_attention.query_key_value.weight": "model-00001-of-00005.safetensors",
236
+ "transformer.encoder.layers.19.self_attention.query_key_value.bias": "model-00001-of-00005.safetensors",
237
+ "transformer.encoder.layers.19.self_attention.dense.weight": "model-00001-of-00005.safetensors",
238
+ "transformer.encoder.layers.19.mlp.dense_h_to_4h.weight": "model-00001-of-00005.safetensors",
239
+ "transformer.encoder.layers.19.mlp.dense_4h_to_h.weight": "model-00001-of-00005.safetensors",
240
+ "transformer.encoder.layers.19.mlp.dense_4h_to_h.bias": "model-00001-of-00005.safetensors",
241
+ "transformer.encoder.layers.15.input_layernorm.weight": "model-00001-of-00005.safetensors",
242
+ "transformer.encoder.layers.15.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
243
+ "transformer.encoder.layers.15.self_attention.query_key_value.weight": "model-00001-of-00005.safetensors",
244
+ "transformer.encoder.layers.15.self_attention.query_key_value.bias": "model-00001-of-00005.safetensors",
245
+ "transformer.encoder.layers.15.self_attention.dense.weight": "model-00001-of-00005.safetensors",
246
+ "transformer.encoder.layers.15.mlp.dense_h_to_4h.weight": "model-00001-of-00005.safetensors",
247
+ "transformer.encoder.layers.15.mlp.dense_4h_to_h.weight": "model-00001-of-00005.safetensors",
248
+ "transformer.encoder.layers.15.mlp.dense_4h_to_h.bias": "model-00001-of-00005.safetensors",
249
+ "transformer.encoder.layers.14.input_layernorm.weight": "model-00001-of-00005.safetensors",
250
+ "transformer.encoder.layers.14.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
251
+ "transformer.encoder.layers.14.self_attention.query_key_value.weight": "model-00001-of-00005.safetensors",
252
+ "transformer.encoder.layers.14.self_attention.query_key_value.bias": "model-00001-of-00005.safetensors",
253
+ "transformer.encoder.layers.14.self_attention.dense.weight": "model-00001-of-00005.safetensors",
254
+ "transformer.encoder.layers.14.mlp.dense_h_to_4h.weight": "model-00001-of-00005.safetensors",
255
+ "transformer.encoder.layers.14.mlp.dense_4h_to_h.weight": "model-00001-of-00005.safetensors",
256
+ "transformer.encoder.layers.14.mlp.dense_4h_to_h.bias": "model-00001-of-00005.safetensors",
257
+ "transformer.encoder.layers.5.input_layernorm.weight": "model-00000-of-00005.safetensors",
258
+ "transformer.encoder.layers.5.post_attention_layernorm.weight": "model-00000-of-00005.safetensors",
259
+ "transformer.encoder.layers.5.self_attention.query_key_value.weight": "model-00000-of-00005.safetensors",
260
+ "transformer.encoder.layers.5.self_attention.query_key_value.bias": "model-00000-of-00005.safetensors",
261
+ "transformer.encoder.layers.5.self_attention.dense.weight": "model-00000-of-00005.safetensors",
262
+ "transformer.encoder.layers.5.mlp.dense_h_to_4h.weight": "model-00000-of-00005.safetensors",
263
+ "transformer.encoder.layers.5.mlp.dense_4h_to_h.weight": "model-00000-of-00005.safetensors",
264
+ "transformer.encoder.layers.5.mlp.dense_4h_to_h.bias": "model-00000-of-00005.safetensors",
265
+ "transformer.encoder.layers.38.input_layernorm.weight": "model-00003-of-00005.safetensors",
266
+ "transformer.encoder.layers.38.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
267
+ "transformer.encoder.layers.38.self_attention.query_key_value.weight": "model-00003-of-00005.safetensors",
268
+ "transformer.encoder.layers.38.self_attention.query_key_value.bias": "model-00003-of-00005.safetensors",
269
+ "transformer.encoder.layers.38.self_attention.dense.weight": "model-00003-of-00005.safetensors",
270
+ "transformer.encoder.layers.38.mlp.dense_h_to_4h.weight": "model-00003-of-00005.safetensors",
271
+ "transformer.encoder.layers.38.mlp.dense_4h_to_h.weight": "model-00003-of-00005.safetensors",
272
+ "transformer.encoder.layers.38.mlp.dense_4h_to_h.bias": "model-00003-of-00005.safetensors",
273
+ "transformer.encoder.layers.11.input_layernorm.weight": "model-00001-of-00005.safetensors",
274
+ "transformer.encoder.layers.11.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
275
+ "transformer.encoder.layers.11.self_attention.query_key_value.weight": "model-00001-of-00005.safetensors",
276
+ "transformer.encoder.layers.11.self_attention.query_key_value.bias": "model-00001-of-00005.safetensors",
277
+ "transformer.encoder.layers.11.self_attention.dense.weight": "model-00001-of-00005.safetensors",
278
+ "transformer.encoder.layers.11.mlp.dense_h_to_4h.weight": "model-00001-of-00005.safetensors",
279
+ "transformer.encoder.layers.11.mlp.dense_4h_to_h.weight": "model-00001-of-00005.safetensors",
280
+ "transformer.encoder.layers.11.mlp.dense_4h_to_h.bias": "model-00001-of-00005.safetensors",
281
+ "transformer.encoder.layers.21.input_layernorm.weight": "model-00002-of-00005.safetensors",
282
+ "transformer.encoder.layers.21.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
283
+ "transformer.encoder.layers.21.self_attention.query_key_value.weight": "model-00002-of-00005.safetensors",
284
+ "transformer.encoder.layers.21.self_attention.query_key_value.bias": "model-00002-of-00005.safetensors",
285
+ "transformer.encoder.layers.21.self_attention.dense.weight": "model-00002-of-00005.safetensors",
286
+ "transformer.encoder.layers.21.mlp.dense_h_to_4h.weight": "model-00002-of-00005.safetensors",
287
+ "transformer.encoder.layers.21.mlp.dense_4h_to_h.weight": "model-00002-of-00005.safetensors",
288
+ "transformer.encoder.layers.21.mlp.dense_4h_to_h.bias": "model-00002-of-00005.safetensors",
289
+ "transformer.encoder.layers.23.input_layernorm.weight": "model-00002-of-00005.safetensors",
290
+ "transformer.encoder.layers.23.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
291
+ "transformer.encoder.layers.23.self_attention.query_key_value.weight": "model-00002-of-00005.safetensors",
292
+ "transformer.encoder.layers.23.self_attention.query_key_value.bias": "model-00002-of-00005.safetensors",
293
+ "transformer.encoder.layers.23.self_attention.dense.weight": "model-00002-of-00005.safetensors",
294
+ "transformer.encoder.layers.23.mlp.dense_h_to_4h.weight": "model-00002-of-00005.safetensors",
295
+ "transformer.encoder.layers.23.mlp.dense_4h_to_h.weight": "model-00002-of-00005.safetensors",
296
+ "transformer.encoder.layers.23.mlp.dense_4h_to_h.bias": "model-00002-of-00005.safetensors",
297
+ "transformer.encoder.layers.24.input_layernorm.weight": "model-00002-of-00005.safetensors",
298
+ "transformer.encoder.layers.24.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
299
+ "transformer.encoder.layers.24.self_attention.query_key_value.weight": "model-00002-of-00005.safetensors",
300
+ "transformer.encoder.layers.24.self_attention.query_key_value.bias": "model-00002-of-00005.safetensors",
301
+ "transformer.encoder.layers.24.self_attention.dense.weight": "model-00002-of-00005.safetensors",
302
+ "transformer.encoder.layers.24.mlp.dense_h_to_4h.weight": "model-00002-of-00005.safetensors",
303
+ "transformer.encoder.layers.24.mlp.dense_4h_to_h.weight": "model-00002-of-00005.safetensors",
304
+ "transformer.encoder.layers.24.mlp.dense_4h_to_h.bias": "model-00002-of-00005.safetensors",
305
+ "transformer.encoder.layers.2.input_layernorm.weight": "model-00000-of-00005.safetensors",
306
+ "transformer.encoder.layers.2.post_attention_layernorm.weight": "model-00000-of-00005.safetensors",
307
+ "transformer.encoder.layers.2.self_attention.query_key_value.weight": "model-00000-of-00005.safetensors",
308
+ "transformer.encoder.layers.2.self_attention.query_key_value.bias": "model-00000-of-00005.safetensors",
309
+ "transformer.encoder.layers.2.self_attention.dense.weight": "model-00000-of-00005.safetensors",
310
+ "transformer.encoder.layers.2.mlp.dense_h_to_4h.weight": "model-00000-of-00005.safetensors",
311
+ "transformer.encoder.layers.2.mlp.dense_4h_to_h.weight": "model-00000-of-00005.safetensors",
312
+ "transformer.encoder.layers.2.mlp.dense_4h_to_h.bias": "model-00000-of-00005.safetensors",
313
+ "transformer.encoder.layers.25.input_layernorm.weight": "model-00002-of-00005.safetensors",
314
+ "transformer.encoder.layers.25.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
315
+ "transformer.encoder.layers.25.self_attention.query_key_value.weight": "model-00002-of-00005.safetensors",
316
+ "transformer.encoder.layers.25.self_attention.query_key_value.bias": "model-00002-of-00005.safetensors",
317
+ "transformer.encoder.layers.25.self_attention.dense.weight": "model-00002-of-00005.safetensors",
318
+ "transformer.encoder.layers.25.mlp.dense_h_to_4h.weight": "model-00002-of-00005.safetensors",
319
+ "transformer.encoder.layers.25.mlp.dense_4h_to_h.weight": "model-00002-of-00005.safetensors",
320
+ "transformer.encoder.layers.25.mlp.dense_4h_to_h.bias": "model-00002-of-00005.safetensors",
321
+ "transformer.encoder.layers.8.input_layernorm.weight": "model-00000-of-00005.safetensors",
322
+ "transformer.encoder.layers.8.post_attention_layernorm.weight": "model-00000-of-00005.safetensors",
323
+ "transformer.encoder.layers.8.self_attention.query_key_value.weight": "model-00000-of-00005.safetensors",
324
+ "transformer.encoder.layers.8.self_attention.query_key_value.bias": "model-00000-of-00005.safetensors",
325
+ "transformer.encoder.layers.8.self_attention.dense.weight": "model-00000-of-00005.safetensors",
326
+ "transformer.encoder.layers.8.mlp.dense_h_to_4h.weight": "model-00000-of-00005.safetensors",
327
+ "transformer.encoder.layers.8.mlp.dense_4h_to_h.weight": "model-00000-of-00005.safetensors",
328
+ "transformer.encoder.layers.8.mlp.dense_4h_to_h.bias": "model-00000-of-00005.safetensors"
329
+ }
330
+ }
modeling_chatglm.py ADDED
@@ -0,0 +1,1224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ PyTorch ChatGLM model. """
2
+
3
+ import math
4
+ import sys
5
+ import torch
6
+ import torch.utils.checkpoint
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+ from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
10
+ from torch.nn.utils import skip_init
11
+ from typing import Optional, Tuple, Union, List, Dict, Any
12
+
13
+ from transformers.modeling_outputs import (
14
+ BaseModelOutputWithPast,
15
+ CausalLMOutputWithPast,
16
+ SequenceClassifierOutputWithPast,
17
+ )
18
+ from transformers.modeling_utils import PreTrainedModel
19
+ from transformers.utils import logging, is_torch_npu_available
20
+ from transformers.generation.logits_process import LogitsProcessor
21
+ from transformers.generation.utils import ModelOutput
22
+
23
+ from .configuration_chatglm import ChatGLMConfig
24
+
25
+ try:
26
+ from transformers.utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available
27
+
28
+ if is_flash_attn_2_available():
29
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
30
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
31
+ except:
32
+ pass
33
+
34
+ from nltk.tokenize import PunktSentenceTokenizer
35
+ import re
36
+
37
+ # flags required to enable jit fusion kernels
38
+
39
+ if sys.platform != 'darwin' and not is_torch_npu_available():
40
+ torch._C._jit_set_profiling_mode(False)
41
+ torch._C._jit_set_profiling_executor(False)
42
+ torch._C._jit_override_can_fuse_on_cpu(True)
43
+ torch._C._jit_override_can_fuse_on_gpu(True)
44
+
45
+ logger = logging.get_logger(__name__)
46
+
47
+ _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
48
+ _CONFIG_FOR_DOC = "ChatGLMConfig"
49
+
50
+
51
+ def default_init(cls, *args, **kwargs):
52
+ return cls(*args, **kwargs)
53
+
54
+
55
+ class InvalidScoreLogitsProcessor(LogitsProcessor):
56
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
57
+ if torch.isnan(scores).any() or torch.isinf(scores).any():
58
+ scores.zero_()
59
+ scores[..., 198] = 5e4
60
+ return scores
61
+
62
+
63
+ def split_tensor_along_last_dim(
64
+ tensor: torch.Tensor,
65
+ num_partitions: int,
66
+ contiguous_split_chunks: bool = False,
67
+ ) -> List[torch.Tensor]:
68
+ """Split a tensor along its last dimension.
69
+ Arguments:
70
+ tensor: input tensor.
71
+ num_partitions: number of partitions to split the tensor
72
+ contiguous_split_chunks: If True, make each chunk contiguous
73
+ in memory.
74
+ Returns:
75
+ A list of Tensors
76
+ """
77
+ # Get the size and dimension.
78
+ last_dim = tensor.dim() - 1
79
+ last_dim_size = tensor.size()[last_dim] // num_partitions
80
+ # Split.
81
+ tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
82
+ # Note: torch.split does not create contiguous tensors by default.
83
+ if contiguous_split_chunks:
84
+ return tuple(chunk.contiguous() for chunk in tensor_list)
85
+
86
+ return tensor_list
87
+
88
+
89
+ class RotaryEmbedding(nn.Module):
90
+ def __init__(self, dim, rope_ratio=1, original_impl=False, device=None, dtype=None):
91
+ super().__init__()
92
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
93
+ self.register_buffer("inv_freq", inv_freq)
94
+ self.dim = dim
95
+ self.original_impl = original_impl
96
+ self.rope_ratio = rope_ratio
97
+
98
+ def forward_impl(
99
+ self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
100
+ ):
101
+ """Enhanced Transformer with Rotary Position Embedding.
102
+ Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
103
+ transformers/rope/__init__.py. MIT License:
104
+ https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
105
+ """
106
+ # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
107
+ base = base * self.rope_ratio
108
+ theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))
109
+
110
+ # Create position indexes `[0, 1, ..., seq_len - 1]`
111
+ seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)
112
+
113
+ # Calculate the product of position index and $\theta_i$
114
+ idx_theta = torch.outer(seq_idx, theta).float()
115
+
116
+ cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
117
+
118
+ # this is to mimic the behaviour of complex32, else we will get different results
119
+ if dtype in (torch.float16, torch.bfloat16, torch.int8):
120
+ cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
121
+ return cache
122
+
123
+ def forward(self, max_seq_len, offset=0):
124
+ return self.forward_impl(
125
+ max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
126
+ )
127
+
128
+
129
+ @torch.jit.script
130
+ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
131
+ # x: [b, np, sq, hn]
132
+ b, np, sq, hn = x.size(0), x.size(1), x.size(2), x.size(3)
133
+ rot_dim = rope_cache.shape[-2] * 2
134
+ x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
135
+ # truncate to support variable sizes
136
+ rope_cache = rope_cache[:, :sq]
137
+ xshaped = x.reshape(b, np, sq, rot_dim // 2, 2)
138
+ rope_cache = rope_cache.view(-1, 1, sq, xshaped.size(3), 2)
139
+ x_out2 = torch.stack(
140
+ [
141
+ xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
142
+ xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
143
+ ],
144
+ -1,
145
+ )
146
+ x_out2 = x_out2.flatten(3)
147
+ return torch.cat((x_out2, x_pass), dim=-1)
148
+
149
+
150
+ class RMSNorm(torch.nn.Module):
151
+ def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
152
+ super().__init__()
153
+ self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
154
+ self.eps = eps
155
+
156
+ def forward(self, hidden_states: torch.Tensor):
157
+ input_dtype = hidden_states.dtype
158
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
159
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
160
+
161
+ return (self.weight * hidden_states).to(input_dtype)
162
+
163
+
164
+ class CoreAttention(torch.nn.Module):
165
+ def __init__(self, config: ChatGLMConfig, layer_number):
166
+ super(CoreAttention, self).__init__()
167
+ self.config = config
168
+ self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
169
+ self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
170
+ if self.apply_query_key_layer_scaling:
171
+ self.attention_softmax_in_fp32 = True
172
+ self.layer_number = max(1, layer_number)
173
+ self.is_causal = True
174
+
175
+ projection_size = config.kv_channels * config.num_attention_heads
176
+
177
+ # Per attention head and per partition values.
178
+ self.hidden_size_per_partition = projection_size
179
+ self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
180
+ self.num_attention_heads_per_partition = config.num_attention_heads
181
+
182
+ coeff = None
183
+ self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
184
+ if self.apply_query_key_layer_scaling:
185
+ coeff = self.layer_number
186
+ self.norm_factor *= coeff
187
+ self.coeff = coeff
188
+
189
+ self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
190
+
191
+ def forward(self, query_layer, key_layer, value_layer, attention_mask):
192
+ # [b, np, sq, sk]
193
+ output_size = (query_layer.size(0), query_layer.size(1), query_layer.size(2), key_layer.size(2))
194
+
195
+ # [b, np, sq, hn] -> [b * np, sq, hn]
196
+ query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1)
197
+ # [b, np, sk, hn] -> [b * np, sk, hn]
198
+ key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1)
199
+
200
+ # preallocting input tensor: [b * np, sq, sk]
201
+ matmul_input_buffer = torch.empty(
202
+ output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
203
+ device=query_layer.device
204
+ )
205
+
206
+ # Raw attention scores. [b * np, sq, sk]
207
+ matmul_result = torch.baddbmm(
208
+ matmul_input_buffer,
209
+ query_layer, # [b * np, sq, hn]
210
+ key_layer.transpose(1, 2), # [b * np, hn, sk]
211
+ beta=0.0,
212
+ alpha=(1.0 / self.norm_factor),
213
+ )
214
+
215
+ # change view to [b, np, sq, sk]
216
+ attention_scores = matmul_result.view(*output_size)
217
+
218
+ # ===========================
219
+ # Attention probs and dropout
220
+ # ===========================
221
+
222
+ # attention scores and attention mask [b, np, sq, sk]
223
+ if self.attention_softmax_in_fp32:
224
+ attention_scores = attention_scores.float()
225
+ if self.coeff is not None:
226
+ attention_scores = attention_scores * self.coeff
227
+ if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
228
+ attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
229
+ device=attention_scores.device, dtype=torch.bool)
230
+ attention_mask.tril_()
231
+ attention_mask = ~attention_mask
232
+ if attention_mask is not None:
233
+ attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
234
+ attention_probs = F.softmax(attention_scores, dim=-1)
235
+ attention_probs = attention_probs.type_as(value_layer)
236
+
237
+ # This is actually dropping out entire tokens to attend to, which might
238
+ # seem a bit unusual, but is taken from the original Transformer paper.
239
+ attention_probs = self.attention_dropout(attention_probs)
240
+
241
+ # query layer shape: [b * np, sq, hn]
242
+ # value layer shape: [b, np, sk, hn]
243
+ # attention shape: [b, np, sq, sk]
244
+ # context layer shape: [b, np, sq, hn]
245
+ output_size = (value_layer.size(0), value_layer.size(1), query_layer.size(1), value_layer.size(3))
246
+ # change view [b * np, sk, hn]
247
+ value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)
248
+ # change view [b * np, sq, sk]
249
+ attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
250
+ # matmul: [b * np, sq, hn]
251
+ context_layer = torch.bmm(attention_probs, value_layer)
252
+ # change view [b, np, sq, hn]
253
+ context_layer = context_layer.view(*output_size)
254
+ # [b, np, sq, hn] --> [b, sq, np, hn]
255
+ context_layer = context_layer.transpose(1, 2).contiguous()
256
+ # [b, sq, np, hn] --> [b, sq, hp]
257
+ splited_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
258
+ context_layer = context_layer.reshape(*splited_context_layer_shape)
259
+
260
+ return context_layer
261
+
262
+
263
+ class SdpaAttention(CoreAttention):
264
+ def forward(self, query_layer, key_layer, value_layer, attention_mask):
265
+ if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
266
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
267
+ is_causal=True,
268
+ dropout_p=self.config.attention_dropout if self.training else 0.0)
269
+ else:
270
+ if attention_mask is not None:
271
+ attention_mask = ~attention_mask
272
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
273
+ attention_mask,
274
+ dropout_p=self.config.attention_dropout if self.training else 0.0)
275
+ context_layer = context_layer.transpose(1, 2).contiguous()
276
+ splited_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
277
+ context_layer = context_layer.reshape(*splited_context_layer_shape)
278
+ return context_layer
279
+
280
+
281
+ def _get_unpad_data(attention_mask):
282
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
283
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
284
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
285
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
286
+ return (
287
+ indices,
288
+ cu_seqlens,
289
+ max_seqlen_in_batch,
290
+ )
291
+
292
+
293
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2
294
+ class FlashAttention2(CoreAttention):
295
+ def __init__(self, *args, **kwargs):
296
+ super().__init__(*args, **kwargs)
297
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
298
+
299
+ def forward(self, query_states, key_states, value_states, attention_mask):
300
+ query_states = query_states.transpose(1, 2)
301
+ key_states = key_states.transpose(1, 2)
302
+ value_states = value_states.transpose(1, 2)
303
+ batch_size, query_length = query_states.shape[:2]
304
+ if not self._flash_attn_uses_top_left_mask:
305
+ causal = self.is_causal
306
+ else:
307
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
308
+ causal = self.is_causal and query_length != 1
309
+ dropout = self.config.attention_dropout if self.training else 0.0
310
+ # Contains at least one padding token in the sequence
311
+ if attention_mask is not None:
312
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
313
+ query_states, key_states, value_states, attention_mask, query_length
314
+ )
315
+
316
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
317
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
318
+
319
+ attn_output_unpad = flash_attn_varlen_func(
320
+ query_states,
321
+ key_states,
322
+ value_states,
323
+ cu_seqlens_q=cu_seqlens_q,
324
+ cu_seqlens_k=cu_seqlens_k,
325
+ max_seqlen_q=max_seqlen_in_batch_q,
326
+ max_seqlen_k=max_seqlen_in_batch_k,
327
+ dropout_p=dropout,
328
+ softmax_scale=None,
329
+ causal=causal,
330
+ )
331
+
332
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
333
+ else:
334
+ attn_output = flash_attn_func(
335
+ query_states, key_states, value_states, dropout, softmax_scale=None, causal=causal
336
+ )
337
+ attn_output = attn_output.reshape(batch_size, query_length, self.hidden_size_per_partition).contiguous()
338
+ return attn_output
339
+
340
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
341
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
342
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
343
+
344
+ key_layer = index_first_axis(
345
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
346
+ )
347
+ value_layer = index_first_axis(
348
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
349
+ )
350
+ if query_length == kv_seq_len:
351
+ query_layer = index_first_axis(
352
+ query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads_per_partition, head_dim),
353
+ indices_k
354
+ )
355
+ cu_seqlens_q = cu_seqlens_k
356
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
357
+ indices_q = indices_k
358
+ elif query_length == 1:
359
+ max_seqlen_in_batch_q = 1
360
+ cu_seqlens_q = torch.arange(
361
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
362
+ ) # There is a memcpy here, that is very bad.
363
+ indices_q = cu_seqlens_q[:-1]
364
+ query_layer = query_layer.squeeze(1)
365
+ else:
366
+ # The -q_len: slice assumes left padding.
367
+ attention_mask = attention_mask[:, -query_length:]
368
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
369
+
370
+ return (
371
+ query_layer,
372
+ key_layer,
373
+ value_layer,
374
+ indices_q,
375
+ (cu_seqlens_q, cu_seqlens_k),
376
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
377
+ )
378
+
379
+
380
+ CORE_ATTENTION_CLASSES = {
381
+ "eager": CoreAttention,
382
+ "sdpa": SdpaAttention,
383
+ "flash_attention_2": FlashAttention2
384
+ }
385
+
386
+
387
+ class SelfAttention(torch.nn.Module):
388
+ """Parallel self-attention layer abstract class.
389
+ Self-attention layer takes input with size [s, b, h]
390
+ and returns output of the same size.
391
+ """
392
+
393
+ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
394
+ super(SelfAttention, self).__init__()
395
+ self.layer_number = max(1, layer_number)
396
+
397
+ self.projection_size = config.kv_channels * config.num_attention_heads
398
+
399
+ # Per attention head and per partition values.
400
+ self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
401
+ self.num_attention_heads_per_partition = config.num_attention_heads
402
+
403
+ self.multi_query_attention = config.multi_query_attention
404
+ self.qkv_hidden_size = 3 * self.projection_size
405
+ if self.multi_query_attention:
406
+ self.num_multi_query_groups_per_partition = config.multi_query_group_num
407
+ self.qkv_hidden_size = (
408
+ self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
409
+ )
410
+ self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
411
+ bias=config.add_bias_linear or config.add_qkv_bias,
412
+ device=device, **_config_to_kwargs(config)
413
+ )
414
+
415
+ self.core_attention = CORE_ATTENTION_CLASSES[config._attn_implementation](config, self.layer_number)
416
+
417
+ # Output.
418
+ self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
419
+ device=device, **_config_to_kwargs(config)
420
+ )
421
+
422
+ def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
423
+ if self.multi_query_attention:
424
+ num_attention_heads = self.num_multi_query_groups_per_partition
425
+ else:
426
+ num_attention_heads = self.num_attention_heads_per_partition
427
+ return torch.empty(
428
+ inference_max_sequence_len,
429
+ batch_size,
430
+ num_attention_heads,
431
+ self.hidden_size_per_attention_head,
432
+ dtype=dtype,
433
+ device=device,
434
+ )
435
+
436
+ def forward(
437
+ self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
438
+ ):
439
+ # hidden_states: [b, sq, h]
440
+
441
+ # =================================================
442
+ # Pre-allocate memory for key-values for inference.
443
+ # =================================================
444
+ # =====================
445
+ # Query, Key, and Value
446
+ # =====================
447
+
448
+ # Attention heads [b, sq, h] --> [b, sq, (np * 3 * hn)]
449
+ mixed_x_layer = self.query_key_value(hidden_states)
450
+
451
+ if self.multi_query_attention:
452
+ (query_layer, key_layer, value_layer) = mixed_x_layer.split(
453
+ [
454
+ self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
455
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
456
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
457
+ ],
458
+ dim=-1,
459
+ )
460
+ query_layer = query_layer.view(
461
+ query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
462
+ )
463
+ key_layer = key_layer.view(
464
+ key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
465
+ )
466
+ value_layer = value_layer.view(
467
+ value_layer.size()[:-1]
468
+ + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
469
+ )
470
+ else:
471
+ new_tensor_shape = mixed_x_layer.size()[:-1] + \
472
+ (self.num_attention_heads_per_partition,
473
+ 3 * self.hidden_size_per_attention_head)
474
+ mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
475
+
476
+ # [b, sq, np, 3 * hn] --> 3 [b, sq, np, hn]
477
+ (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
478
+
479
+ # [b, sq, np, hn] -> [b, np, sq, hn]
480
+ query_layer, key_layer, value_layer = [k.transpose(1, 2) for k in [query_layer, key_layer, value_layer]]
481
+
482
+ # apply relative positional encoding (rotary embedding)
483
+ if rotary_pos_emb is not None:
484
+ query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
485
+ key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
486
+
487
+ # adjust key and value for inference
488
+ if kv_cache is not None:
489
+ cache_k, cache_v = kv_cache
490
+ key_layer = torch.cat((cache_k, key_layer), dim=2)
491
+ value_layer = torch.cat((cache_v, value_layer), dim=2)
492
+ if use_cache:
493
+ if kv_cache is None:
494
+ kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)),
495
+ dim=1)
496
+ else:
497
+ kv_cache = (key_layer, value_layer)
498
+ else:
499
+ kv_cache = None
500
+
501
+ if self.multi_query_attention:
502
+ key_layer = key_layer.unsqueeze(2)
503
+ key_layer = key_layer.expand(
504
+ -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1, -1
505
+ )
506
+ key_layer = key_layer.contiguous().view(
507
+ key_layer.size()[:1] + (self.num_attention_heads_per_partition,) + key_layer.size()[3:]
508
+ )
509
+ value_layer = value_layer.unsqueeze(2)
510
+ value_layer = value_layer.expand(
511
+ -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1, -1
512
+ )
513
+ value_layer = value_layer.contiguous().view(
514
+ value_layer.size()[:1] + (self.num_attention_heads_per_partition,) + value_layer.size()[3:]
515
+ )
516
+
517
+ # ==================================
518
+ # core attention computation
519
+ # ==================================
520
+
521
+ context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
522
+
523
+ # =================
524
+ # Output. [sq, b, h]
525
+ # =================
526
+
527
+ output = self.dense(context_layer)
528
+
529
+ return output, kv_cache
530
+
531
+
532
+ def _config_to_kwargs(args):
533
+ common_kwargs = {
534
+ "dtype": args.torch_dtype,
535
+ }
536
+ return common_kwargs
537
+
538
+
539
+ class MLP(torch.nn.Module):
540
+ """MLP.
541
+ MLP will take the input with h hidden state, project it to 4*h
542
+ hidden dimension, perform nonlinear transformation, and project the
543
+ state back into h hidden dimension.
544
+ """
545
+
546
+ def __init__(self, config: ChatGLMConfig, device=None):
547
+ super(MLP, self).__init__()
548
+
549
+ self.add_bias = config.add_bias_linear
550
+
551
+ # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
552
+ self.dense_h_to_4h = nn.Linear(
553
+ config.hidden_size,
554
+ config.ffn_hidden_size * 2,
555
+ bias=self.add_bias,
556
+ device=device,
557
+ **_config_to_kwargs(config)
558
+ )
559
+
560
+ def swiglu(x):
561
+ x = torch.chunk(x, 2, dim=-1)
562
+ return F.silu(x[0]) * x[1]
563
+
564
+ self.activation_func = swiglu
565
+
566
+ # Project back to h.
567
+ self.dense_4h_to_h = nn.Linear(
568
+ config.ffn_hidden_size,
569
+ config.hidden_size,
570
+ bias=self.add_bias,
571
+ device=device,
572
+ **_config_to_kwargs(config)
573
+ )
574
+
575
+ def forward(self, hidden_states):
576
+ # [s, b, 4hp]
577
+ intermediate_parallel = self.dense_h_to_4h(hidden_states)
578
+ intermediate_parallel = self.activation_func(intermediate_parallel)
579
+ # [s, b, h]
580
+ output = self.dense_4h_to_h(intermediate_parallel)
581
+ return output
582
+
583
+
584
+ class GLMBlock(torch.nn.Module):
585
+ """A single transformer layer.
586
+ Transformer layer takes input with size [s, b, h] and returns an
587
+ output of the same size.
588
+ """
589
+
590
+ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
591
+ super(GLMBlock, self).__init__()
592
+ self.layer_number = layer_number
593
+
594
+ self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
595
+
596
+ self.fp32_residual_connection = config.fp32_residual_connection
597
+
598
+ LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
599
+ # Layernorm on the input data.
600
+ self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
601
+ dtype=config.torch_dtype)
602
+
603
+ # Self attention.
604
+ self.self_attention = SelfAttention(config, layer_number, device=device)
605
+ self.hidden_dropout = config.hidden_dropout
606
+
607
+ # Layernorm on the attention output
608
+ self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
609
+ dtype=config.torch_dtype)
610
+
611
+ # MLP
612
+ self.mlp = MLP(config, device=device)
613
+
614
+ def forward(
615
+ self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
616
+ ):
617
+ # hidden_states: [s, b, h]
618
+
619
+ # Layer norm at the beginning of the transformer layer.
620
+ layernorm_output = self.input_layernorm(hidden_states)
621
+ # Self attention.
622
+ attention_output, kv_cache = self.self_attention(
623
+ layernorm_output,
624
+ attention_mask,
625
+ rotary_pos_emb,
626
+ kv_cache=kv_cache,
627
+ use_cache=use_cache
628
+ )
629
+
630
+ # Residual connection.
631
+ if self.apply_residual_connection_post_layernorm:
632
+ residual = layernorm_output
633
+ else:
634
+ residual = hidden_states
635
+
636
+ layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
637
+ layernorm_input = residual + layernorm_input
638
+
639
+ # Layer norm post the self attention.
640
+ layernorm_output = self.post_attention_layernorm(layernorm_input)
641
+
642
+ # MLP.
643
+ mlp_output = self.mlp(layernorm_output)
644
+
645
+ # Second residual connection.
646
+ if self.apply_residual_connection_post_layernorm:
647
+ residual = layernorm_output
648
+ else:
649
+ residual = layernorm_input
650
+
651
+ output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
652
+ output = residual + output
653
+
654
+ return output, kv_cache
655
+
656
+
657
+ class GLMTransformer(torch.nn.Module):
658
+ """Transformer class."""
659
+
660
+ def __init__(self, config: ChatGLMConfig, device=None):
661
+ super(GLMTransformer, self).__init__()
662
+
663
+ self.fp32_residual_connection = config.fp32_residual_connection
664
+ self.post_layer_norm = config.post_layer_norm
665
+
666
+ # Number of layers.
667
+ self.num_layers = config.num_layers
668
+
669
+ # Transformer layers.
670
+ def build_layer(layer_number):
671
+ return GLMBlock(config, layer_number, device=device)
672
+
673
+ self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
674
+
675
+ if self.post_layer_norm:
676
+ LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
677
+ # Final layer norm before output.
678
+ self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
679
+ dtype=config.torch_dtype)
680
+
681
+ self.gradient_checkpointing = False
682
+
683
+ def _get_layer(self, layer_number):
684
+ return self.layers[layer_number]
685
+
686
+ def forward(
687
+ self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,
688
+ use_cache: Optional[bool] = True,
689
+ output_hidden_states: Optional[bool] = False,
690
+ ):
691
+ if not kv_caches:
692
+ kv_caches = [None for _ in range(self.num_layers)]
693
+ presents = () if use_cache else None
694
+ if self.gradient_checkpointing and self.training:
695
+ if use_cache:
696
+ logger.warning_once(
697
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
698
+ )
699
+ use_cache = False
700
+
701
+ all_self_attentions = None
702
+ all_hidden_states = () if output_hidden_states else None
703
+ for index in range(self.num_layers):
704
+ if output_hidden_states:
705
+ all_hidden_states = all_hidden_states + (hidden_states,)
706
+
707
+ layer = self._get_layer(index)
708
+ if self.gradient_checkpointing and self.training:
709
+ layer_ret = torch.utils.checkpoint.checkpoint(
710
+ layer,
711
+ hidden_states,
712
+ attention_mask,
713
+ rotary_pos_emb,
714
+ kv_caches[index],
715
+ use_cache,
716
+ use_reentrant=False
717
+ )
718
+ else:
719
+ layer_ret = layer(
720
+ hidden_states,
721
+ attention_mask,
722
+ rotary_pos_emb,
723
+ kv_cache=kv_caches[index],
724
+ use_cache=use_cache
725
+ )
726
+ hidden_states, kv_cache = layer_ret
727
+ if use_cache:
728
+ # token by token decoding, use tuple format
729
+ if kv_caches[0] is not None:
730
+ presents = presents + (kv_cache,)
731
+ # prefilling in decoding, use tensor format to save cuda memory
732
+ else:
733
+ if len(presents) == 0:
734
+ presents = kv_cache
735
+ else:
736
+ presents = torch.cat((presents, kv_cache.to(presents.device)), dim=0)
737
+
738
+ if output_hidden_states:
739
+ all_hidden_states = all_hidden_states + (hidden_states,)
740
+
741
+ # Final layer norm.
742
+ if self.post_layer_norm:
743
+ hidden_states = self.final_layernorm(hidden_states)
744
+
745
+ return hidden_states, presents, all_hidden_states, all_self_attentions
746
+
747
+
748
+ class ChatGLMPreTrainedModel(PreTrainedModel):
749
+ """
750
+ An abstract class to handle weights initialization and
751
+ a simple interface for downloading and loading pretrained models.
752
+ """
753
+
754
+ is_parallelizable = False
755
+ supports_gradient_checkpointing = True
756
+ config_class = ChatGLMConfig
757
+ base_model_prefix = "transformer"
758
+ _no_split_modules = ["GLMBlock"]
759
+ _supports_flash_attn_2 = True
760
+ _supports_sdpa = True
761
+
762
+ def _init_weights(self, module: nn.Module):
763
+ """Initialize the weights."""
764
+ return
765
+
766
+ def get_masks(self, input_ids, past_key_values, padding_mask=None):
767
+ if self.config._attn_implementation == "flash_attention_2":
768
+ if padding_mask is not None and not padding_mask.all():
769
+ return padding_mask
770
+ return None
771
+ batch_size, seq_length = input_ids.shape
772
+ full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
773
+ full_attention_mask.tril_()
774
+ past_length = 0
775
+ if past_key_values:
776
+ past_length = past_key_values[0][0].shape[2]
777
+ if past_length:
778
+ full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
779
+ device=input_ids.device), full_attention_mask), dim=-1)
780
+ if padding_mask is not None:
781
+ full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
782
+ if not past_length and padding_mask is not None:
783
+ full_attention_mask -= padding_mask.unsqueeze(-1) - 1
784
+ full_attention_mask = (full_attention_mask < 0.5).bool()
785
+ full_attention_mask.unsqueeze_(1)
786
+ return full_attention_mask
787
+
788
+ def get_position_ids(self, input_ids, device):
789
+ batch_size, seq_length = input_ids.shape
790
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
791
+ return position_ids
792
+
793
+ class Embedding(torch.nn.Module):
794
+ """Language model embeddings."""
795
+
796
+ def __init__(self, config: ChatGLMConfig, device=None):
797
+ super(Embedding, self).__init__()
798
+
799
+ self.hidden_size = config.hidden_size
800
+ # Word embeddings (parallel).
801
+ self.word_embeddings = nn.Embedding(
802
+ config.padded_vocab_size,
803
+ self.hidden_size,
804
+ dtype=config.torch_dtype,
805
+ device=device
806
+ )
807
+ self.fp32_residual_connection = config.fp32_residual_connection
808
+
809
+ def forward(self, input_ids):
810
+ # Embeddings.
811
+ words_embeddings = self.word_embeddings(input_ids)
812
+ embeddings = words_embeddings
813
+ # If the input flag for fp32 residual connection is set, convert for float.
814
+ if self.fp32_residual_connection:
815
+ embeddings = embeddings.float()
816
+ return embeddings
817
+
818
+
819
+ class ChatGLMModel(ChatGLMPreTrainedModel):
820
+ def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
821
+ super().__init__(config)
822
+ if empty_init:
823
+ init_method = skip_init
824
+ else:
825
+ init_method = default_init
826
+ init_kwargs = {}
827
+ if device is not None:
828
+ init_kwargs["device"] = device
829
+ self.embedding = init_method(Embedding, config, **init_kwargs)
830
+ self.num_layers = config.num_layers
831
+ self.multi_query_group_num = config.multi_query_group_num
832
+ self.kv_channels = config.kv_channels
833
+
834
+ # Rotary positional embeddings
835
+ self.seq_length = config.seq_length
836
+ rotary_dim = (
837
+ config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
838
+ )
839
+
840
+ self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, rope_ratio=config.rope_ratio,
841
+ original_impl=config.original_rope,
842
+ device=device, dtype=config.torch_dtype)
843
+ self.encoder = init_method(GLMTransformer, config, **init_kwargs)
844
+ self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
845
+ dtype=config.torch_dtype, **init_kwargs)
846
+
847
+ def get_input_embeddings(self):
848
+ return self.embedding.word_embeddings
849
+
850
+ def set_input_embeddings(self, value):
851
+ self.embedding.word_embeddings = value
852
+
853
+ def forward(
854
+ self,
855
+ input_ids,
856
+ position_ids: Optional[torch.Tensor] = None,
857
+ attention_mask: Optional[torch.BoolTensor] = None,
858
+ full_attention_mask: Optional[torch.BoolTensor] = None,
859
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
860
+ inputs_embeds: Optional[torch.Tensor] = None,
861
+ use_cache: Optional[bool] = None,
862
+ output_attentions: Optional[bool] = None,
863
+ output_hidden_states: Optional[bool] = None,
864
+ return_dict: Optional[bool] = None,
865
+ ):
866
+ output_hidden_states = (
867
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
868
+ )
869
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
870
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
871
+
872
+ batch_size, seq_length = input_ids.shape
873
+
874
+ if inputs_embeds is None:
875
+ inputs_embeds = self.embedding(input_ids)
876
+
877
+ if full_attention_mask is None:
878
+ if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
879
+ full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
880
+
881
+ # Rotary positional embeddings
882
+ rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
883
+ if position_ids is not None:
884
+ rotary_pos_emb = rotary_pos_emb[position_ids]
885
+ else:
886
+ rotary_pos_emb = rotary_pos_emb[None, :seq_length]
887
+
888
+ # Run encoder.
889
+ hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
890
+ inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
891
+ kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
892
+ )
893
+ if presents is not None and type(presents) is torch.Tensor:
894
+ presents = presents.split(1, dim=0)
895
+ presents = list(presents)
896
+ presents = [list(x.squeeze(0).split(1, dim=0)) for x in presents]
897
+ presents = [tuple([x.squeeze(0) for x in y]) for y in presents]
898
+ presents = tuple(presents)
899
+
900
+ if not return_dict:
901
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
902
+
903
+ return BaseModelOutputWithPast(
904
+ last_hidden_state=hidden_states,
905
+ past_key_values=presents,
906
+ hidden_states=all_hidden_states,
907
+ attentions=all_self_attentions,
908
+ )
909
+
910
+
911
+ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
912
+ def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
913
+ super().__init__(config)
914
+
915
+ self.max_sequence_length = config.max_length
916
+ self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
917
+ self.config = config
918
+
919
+ def _update_model_kwargs_for_generation(
920
+ self,
921
+ outputs: ModelOutput,
922
+ model_kwargs: Dict[str, Any],
923
+ is_encoder_decoder: bool = False,
924
+ ) -> Dict[str, Any]:
925
+ # update past_key_values
926
+ cache_name, cache = self._extract_past_from_model_output(outputs)
927
+ model_kwargs[cache_name] = cache
928
+
929
+ # update attention mask
930
+ if "attention_mask" in model_kwargs:
931
+ attention_mask = model_kwargs["attention_mask"]
932
+ model_kwargs["attention_mask"] = torch.cat(
933
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
934
+ )
935
+
936
+ # update position ids
937
+ if "position_ids" in model_kwargs:
938
+ position_ids = model_kwargs["position_ids"]
939
+ new_position_id = position_ids[..., -1:].clone()
940
+ new_position_id += 1
941
+ model_kwargs["position_ids"] = torch.cat(
942
+ [position_ids, new_position_id], dim=-1
943
+ )
944
+
945
+ model_kwargs["is_first_forward"] = False
946
+ return model_kwargs
947
+
948
+ def prepare_inputs_for_generation(
949
+ self,
950
+ input_ids: torch.LongTensor,
951
+ past_key_values: Optional[torch.Tensor] = None,
952
+ attention_mask: Optional[torch.Tensor] = None,
953
+ position_ids: Optional[torch.Tensor] = None,
954
+ use_cache: Optional[bool] = None,
955
+ is_first_forward: bool = True,
956
+ **kwargs
957
+ ) -> dict:
958
+ # only last token for input_ids if past is not None
959
+ if position_ids is None:
960
+ position_ids = self.get_position_ids(input_ids, device=input_ids.device)
961
+ if not is_first_forward:
962
+ if past_key_values is not None:
963
+ position_ids = position_ids[..., -1:]
964
+ input_ids = input_ids[:, -1:]
965
+ return {
966
+ "input_ids": input_ids,
967
+ "past_key_values": past_key_values,
968
+ "position_ids": position_ids,
969
+ "attention_mask": attention_mask,
970
+ "return_last_logit": True,
971
+ "use_cache": use_cache
972
+ }
973
+
974
+ def forward(
975
+ self,
976
+ input_ids: Optional[torch.Tensor] = None,
977
+ position_ids: Optional[torch.Tensor] = None,
978
+ attention_mask: Optional[torch.Tensor] = None,
979
+ past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
980
+ inputs_embeds: Optional[torch.Tensor] = None,
981
+ labels: Optional[torch.Tensor] = None,
982
+ use_cache: Optional[bool] = None,
983
+ output_attentions: Optional[bool] = None,
984
+ output_hidden_states: Optional[bool] = None,
985
+ return_dict: Optional[bool] = None,
986
+ return_last_logit: Optional[bool] = False,
987
+ ):
988
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
989
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
990
+
991
+ transformer_outputs = self.transformer(
992
+ input_ids=input_ids,
993
+ position_ids=position_ids,
994
+ attention_mask=attention_mask,
995
+ past_key_values=past_key_values,
996
+ inputs_embeds=inputs_embeds,
997
+ use_cache=use_cache,
998
+ output_hidden_states=output_hidden_states,
999
+ return_dict=return_dict,
1000
+ )
1001
+
1002
+ hidden_states = transformer_outputs[0]
1003
+ if return_last_logit:
1004
+ hidden_states = hidden_states[:, -1:]
1005
+ lm_logits = self.transformer.output_layer(hidden_states)
1006
+
1007
+ loss = None
1008
+ if labels is not None:
1009
+ lm_logits = lm_logits.to(torch.float32)
1010
+
1011
+ # Shift so that tokens < n predict n
1012
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1013
+ shift_labels = labels[..., 1:].contiguous()
1014
+ # Flatten the tokens
1015
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
1016
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1017
+
1018
+ lm_logits = lm_logits.to(hidden_states.dtype)
1019
+ loss = loss.to(hidden_states.dtype)
1020
+
1021
+ if not return_dict:
1022
+ output = (lm_logits,) + transformer_outputs[1:]
1023
+ return ((loss,) + output) if loss is not None else output
1024
+
1025
+ return CausalLMOutputWithPast(
1026
+ loss=loss,
1027
+ logits=lm_logits,
1028
+ past_key_values=transformer_outputs.past_key_values,
1029
+ hidden_states=transformer_outputs.hidden_states,
1030
+ attentions=transformer_outputs.attentions,
1031
+ )
1032
+
1033
+ @staticmethod
1034
+ def _reorder_cache(
1035
+ past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
1036
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
1037
+ """
1038
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1039
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1040
+ beam_idx at every generation step.
1041
+ Output shares the same memory storage as `past`.
1042
+ """
1043
+ return tuple(
1044
+ (
1045
+ layer_past[0].index_select(0, beam_idx.to(layer_past[0].device)),
1046
+ layer_past[1].index_select(0, beam_idx.to(layer_past[1].device)),
1047
+ )
1048
+ for layer_past in past
1049
+ )
1050
+
1051
+ @torch.inference_mode()
1052
+ def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
1053
+ max_length: int = 131072, num_beams=1, do_sample=True, top_p=0.7, temperature=0.95,
1054
+ **kwargs):
1055
+ if history is None:
1056
+ history = []
1057
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
1058
+ "temperature": temperature, **kwargs}
1059
+ inputs = tokenizer.build_chat_input(query, history=history, role=role)
1060
+ inputs = inputs.to(self.device)
1061
+ eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
1062
+ tokenizer.get_command("<|observation|>")]
1063
+ outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
1064
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
1065
+ response = tokenizer.decode(outputs).strip()
1066
+ history.append({"role": role, "content": query})
1067
+ return response, history
1068
+
1069
+ def query_longcite(self, context, query, tokenizer, max_input_length=128000, max_new_tokens=1024, temperature=0.95):
1070
+
1071
+ def text_split_by_punctuation(original_text, return_dict=False):
1072
+ # text = re.sub(r'([a-z])\.([A-Z])', r'\1. \2', original_text) # separate period without space
1073
+ text = original_text
1074
+ custom_sent_tokenizer = PunktSentenceTokenizer(text)
1075
+ punctuations = r"([。;!?])" # For Chinese support
1076
+
1077
+ separated = custom_sent_tokenizer.tokenize(text)
1078
+ separated = sum([re.split(punctuations, s) for s in separated], [])
1079
+ # Put the punctuations back to the sentence
1080
+ for i in range(1, len(separated)):
1081
+ if re.match(punctuations, separated[i]):
1082
+ separated[i-1] += separated[i]
1083
+ separated[i] = ''
1084
+
1085
+ separated = [s for s in separated if s != ""]
1086
+ if len(separated) == 1:
1087
+ separated = original_text.split('\n\n')
1088
+ separated = [s.strip() for s in separated if s.strip() != ""]
1089
+ if not return_dict:
1090
+ return separated
1091
+ else:
1092
+ pos = 0
1093
+ res = []
1094
+ for i, sent in enumerate(separated):
1095
+ st = original_text.find(sent, pos)
1096
+ assert st != -1, sent
1097
+ ed = st + len(sent)
1098
+ res.append(
1099
+ {
1100
+ 'c_idx': i,
1101
+ 'content': sent,
1102
+ 'start_idx': st,
1103
+ 'end_idx': ed,
1104
+ }
1105
+ )
1106
+ pos = ed
1107
+ return res
1108
+
1109
+ def get_prompt(context, question):
1110
+ sents = text_split_by_punctuation(context, return_dict=True)
1111
+ splited_context = ""
1112
+ for i, s in enumerate(sents):
1113
+ st, ed = s['start_idx'], s['end_idx']
1114
+ assert s['content'] == context[st:ed], s
1115
+ ed = sents[i+1]['start_idx'] if i < len(sents)-1 else len(context)
1116
+ sents[i] = {
1117
+ 'content': context[st:ed],
1118
+ 'start': st,
1119
+ 'end': ed,
1120
+ 'c_idx': s['c_idx'],
1121
+ }
1122
+ splited_context += f"<C{i}>"+context[st:ed]
1123
+ prompt = '''Please answer the user's question based on the following document. When a sentence S in your response uses information from some chunks in the document (i.e., <C{s1}>-<C_{e1}>, <C{s2}>-<C{e2}>, ...), please append these chunk numbers to S in the format "<statement>{S}<cite>[{s1}-{e1}][{s2}-{e2}]...</cite></statement>". You must answer in the same language as the user's question.\n\n[Document Start]\n%s\n[Document End]\n\n%s''' % (splited_context, question)
1124
+ return prompt, sents, splited_context
1125
+
1126
+ def get_citations(statement, sents):
1127
+ c_texts = re.findall(r'<cite>(.*?)</cite>', statement, re.DOTALL)
1128
+ spans = sum([re.findall(r"\[([0-9]+\-[0-9]+)\]", c_text, re.DOTALL) for c_text in c_texts], [])
1129
+ statement = re.sub(r'<cite>(.*?)</cite>', '', statement, flags=re.DOTALL)
1130
+ merged_citations = []
1131
+ for i, s in enumerate(spans):
1132
+ try:
1133
+ st, ed = [int(x) for x in s.split('-')]
1134
+ if st > len(sents) - 1 or ed < st:
1135
+ continue
1136
+ st, ed = max(0, st), min(ed, len(sents)-1)
1137
+ assert st <= ed, str(c_texts) + '\t' + str(len(sents))
1138
+ if len(merged_citations) > 0 and st == merged_citations[-1]['end_sentence_idx'] + 1:
1139
+ merged_citations[-1].update({
1140
+ "end_sentence_idx": ed,
1141
+ 'end_char_idx': sents[ed]['end'],
1142
+ 'cite': ''.join([x['content'] for x in sents[merged_citations[-1]['start_sentence_idx']:ed+1]]),
1143
+ })
1144
+ else:
1145
+ merged_citations.append({
1146
+ "start_sentence_idx": st,
1147
+ "end_sentence_idx": ed,
1148
+ "start_char_idx": sents[st]['start'],
1149
+ 'end_char_idx': sents[ed]['end'],
1150
+ 'cite': ''.join([x['content'] for x in sents[st:ed+1]]),
1151
+ })
1152
+ except:
1153
+ print(c_texts, len(sents), statement)
1154
+ raise
1155
+ return statement, merged_citations[:3]
1156
+
1157
+ def postprocess(answer, sents, splited_context):
1158
+ res = []
1159
+ pos = 0
1160
+ new_answer = ""
1161
+ while True:
1162
+ st = answer.find("<statement>", pos)
1163
+ if st == -1:
1164
+ st = len(answer)
1165
+ ed = answer.find("</statement>", st)
1166
+ statement = answer[pos:st]
1167
+ if len(statement.strip()) > 5:
1168
+ res.append({
1169
+ "statement": statement,
1170
+ "citation": []
1171
+ })
1172
+ new_answer += f"<statement>{statement}<cite></cite></statement>"
1173
+ else:
1174
+ res.append({
1175
+ "statement": statement,
1176
+ "citation": None,
1177
+ })
1178
+ new_answer += statement
1179
+
1180
+ if ed == -1:
1181
+ break
1182
+
1183
+ statement = answer[st+len("<statement>"):ed]
1184
+ if len(statement.strip()) > 0:
1185
+ statement, citations = get_citations(statement, sents)
1186
+ res.append({
1187
+ "statement": statement,
1188
+ "citation": citations
1189
+ })
1190
+ c_str = ''.join(['[{}-{}]'.format(c['start_sentence_idx'], c['end_sentence_idx']) for c in citations])
1191
+ new_answer += f"<statement>{statement}<cite>{c_str}</cite></statement>"
1192
+ else:
1193
+ res.append({
1194
+ "statement": statement,
1195
+ "citation": None,
1196
+ })
1197
+ new_answer += statement
1198
+ pos = ed + len("</statement>")
1199
+ return {
1200
+ "answer": new_answer.strip(),
1201
+ "statements_with_citations": [x for x in res if x['citation'] is not None],
1202
+ "splited_context": splited_context.strip(),
1203
+ "all_statements": res,
1204
+ }
1205
+
1206
+ def truncate_from_middle(prompt, max_input_length=None, tokenizer=None):
1207
+ if max_input_length is None:
1208
+ return prompt
1209
+ else:
1210
+ assert tokenizer is not None
1211
+ tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
1212
+ if len(tokenized_prompt) > max_input_length:
1213
+ half = int(max_input_length/2)
1214
+ prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)+tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
1215
+ return prompt
1216
+
1217
+ prompt, sents, splited_context = get_prompt(context, query)
1218
+ prompt = truncate_from_middle(prompt, max_input_length, tokenizer)
1219
+ output, _ = self.chat(tokenizer, prompt, history=[], max_new_tokens=max_new_tokens, temperature=temperature)
1220
+ result = postprocess(output, sents, splited_context)
1221
+ return result
1222
+
1223
+
1224
+
tokenization_chatglm.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import regex as re
2
+ import base64
3
+ import os
4
+ import json
5
+ import tiktoken
6
+ from transformers import PreTrainedTokenizer
7
+ from typing import List, Optional, Union, Dict
8
+ from transformers import PreTrainedTokenizer
9
+ from transformers.utils import logging, PaddingStrategy
10
+ from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
11
+
12
+
13
+ class ChatGLM4Tokenizer(PreTrainedTokenizer):
14
+ vocab_files_names = {"vocab_file": "tokenizer.model"}
15
+ model_input_names = ["input_ids", "attention_mask", "position_ids"]
16
+
17
+ def __init__(
18
+ self,
19
+ vocab_file,
20
+ padding_side="left",
21
+ clean_up_tokenization_spaces=False,
22
+ encode_special_tokens=False,
23
+ **kwargs
24
+ ):
25
+ self.name = "GLMTokenizer"
26
+ self.vocab_file = vocab_file
27
+ pat_str = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
28
+ self.pat_str = re.compile(pat_str)
29
+ self.encode_special_tokens = encode_special_tokens
30
+
31
+ mergeable_ranks = {}
32
+ with open(vocab_file) as f:
33
+ for line in f:
34
+ token, rank = line.strip().split()
35
+ rank = int(rank)
36
+ token = base64.b64decode(token)
37
+ mergeable_ranks[token] = rank
38
+
39
+ self.mergeable_ranks = mergeable_ranks
40
+ self.special_tokens = ["<|endoftext|>", "[MASK]", "[gMASK]", "[sMASK]", "<sop>", "<eop>", "<|system|>",
41
+ "<|user|>", "<|assistant|>", "<|observation|>", "<|begin_of_image|>", "<|end_of_image|>",
42
+ "<|begin_of_video|>", "<|end_of_video|>"]
43
+
44
+ self.special_tokens = {
45
+ token: idx for idx, token in enumerate(self.special_tokens, start=len(mergeable_ranks))
46
+ }
47
+ self.special_token_ids = {idx: token for token, idx in self.special_tokens.items()}
48
+
49
+ self.tokenizer = tiktoken.Encoding(
50
+ name="my_tokenizer",
51
+ pat_str=pat_str,
52
+ mergeable_ranks=mergeable_ranks,
53
+ special_tokens=self.special_tokens
54
+ )
55
+ self.decoder = {rank: token for token, rank in mergeable_ranks.items()}
56
+ self.n_words = len(self.decoder) + len(self.special_tokens)
57
+
58
+ super().__init__(
59
+ padding_side=padding_side,
60
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
61
+ **kwargs
62
+ )
63
+
64
+ def get_command(self, token):
65
+ assert token in self.special_tokens
66
+ return self.special_tokens[token]
67
+
68
+ @property
69
+ def vocab_size(self):
70
+ return self.n_words
71
+
72
+ @property
73
+ def eos_token_id(self):
74
+ return self.get_command("<|endoftext|>")
75
+
76
+ def get_vocab(self):
77
+ """ Returns vocab as a dict """
78
+ vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
79
+ vocab.update(self.added_tokens_encoder)
80
+ return vocab
81
+
82
+ def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
83
+ """
84
+ Converts a sequence of tokens in a single string.
85
+ """
86
+ text = ""
87
+ temp = b""
88
+ for t in tokens:
89
+ if isinstance(t, str):
90
+ if temp:
91
+ text += temp.decode("utf-8", errors="replace")
92
+ temp = b""
93
+ text += t
94
+ elif isinstance(t, bytes):
95
+ temp += t
96
+ else:
97
+ raise TypeError("token should only be of type types or str")
98
+ if temp:
99
+ text += temp.decode("utf-8", errors="replace")
100
+ return text
101
+
102
+ def _tokenize(self, text, **kwargs):
103
+ tokens = []
104
+ if self.encode_special_tokens:
105
+ ids = self.tokenizer.encode(text, allowed_special="all")
106
+ else:
107
+ ids = self.tokenizer.encode(text, disallowed_special=())
108
+ for t in ids:
109
+ tokens.append(self.decoder[t])
110
+ return tokens
111
+
112
+ def _convert_token_to_id(self, token):
113
+ """ Converts a token (str) in an id using the vocab. """
114
+ if token in self.special_tokens:
115
+ return self.special_tokens[token]
116
+ return self.mergeable_ranks[token]
117
+
118
+ def _convert_id_to_token(self, index):
119
+ """Converts an index (integer) in a token (str) using the vocab."""
120
+ if index in self.special_token_ids:
121
+ return self.special_token_ids[index]
122
+ return self.decoder[index]
123
+
124
+ def save_vocabulary(self, save_directory, filename_prefix=None):
125
+ """
126
+ Save the vocabulary and special tokens file to a directory.
127
+
128
+ Args:
129
+ save_directory (`str`):
130
+ The directory in which to save the vocabulary.
131
+ filename_prefix (`str`, *optional*):
132
+ An optional prefix to add to the named of the saved files.
133
+
134
+ Returns:
135
+ `Tuple(str)`: Paths to the files saved.
136
+ """
137
+ if os.path.isdir(save_directory):
138
+ vocab_file = os.path.join(
139
+ save_directory, self.vocab_files_names["vocab_file"]
140
+ )
141
+ else:
142
+ vocab_file = save_directory
143
+
144
+ with open(self.vocab_file, 'rb') as fin:
145
+ proto_str = fin.read()
146
+
147
+ with open(vocab_file, "wb") as writer:
148
+ writer.write(proto_str)
149
+
150
+ return (vocab_file,)
151
+
152
+ def get_prefix_tokens(self):
153
+ prefix_tokens = [self.get_command("[gMASK]"), self.get_command("<sop>")]
154
+ return prefix_tokens
155
+
156
+ def build_single_message(self, role, metadata, message):
157
+ assert role in ["system", "user", "assistant", "observation"], role
158
+ role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
159
+ message_tokens = self.tokenizer.encode(message, disallowed_special=())
160
+ tokens = role_tokens + message_tokens
161
+ return tokens
162
+
163
+ def build_chat_input(self, query, history=None, role="user"):
164
+ if history is None:
165
+ history = []
166
+ input_ids = []
167
+ for item in history:
168
+ content = item["content"]
169
+ if item["role"] == "system" and "tools" in item:
170
+ for function in item["tools"]:
171
+ content += f"\n\n## {function['name']}\n\n{json.dumps(function, ensure_ascii=False, indent=4)}"
172
+ content += "\n在调用上述函数时,请使用 Json 格式表示调用的参数。"
173
+ input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
174
+ input_ids.extend(self.build_single_message(role, "", query))
175
+ input_ids.extend([self.get_command("<|assistant|>")])
176
+ return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)
177
+
178
+ def build_inputs_with_special_tokens(
179
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
180
+ ) -> List[int]:
181
+ """
182
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
183
+ adding special tokens. A BERT sequence has the following format:
184
+
185
+ - single sequence: `[CLS] X [SEP]`
186
+ - pair of sequences: `[CLS] A [SEP] B [SEP]`
187
+
188
+ Args:
189
+ token_ids_0 (`List[int]`):
190
+ List of IDs to which the special tokens will be added.
191
+ token_ids_1 (`List[int]`, *optional*):
192
+ Optional second list of IDs for sequence pairs.
193
+
194
+ Returns:
195
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
196
+ """
197
+ prefix_tokens = self.get_prefix_tokens()
198
+ token_ids_0 = prefix_tokens + token_ids_0
199
+ if token_ids_1 is not None:
200
+ token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
201
+ return token_ids_0
202
+
203
+ def _pad(
204
+ self,
205
+ encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
206
+ max_length: Optional[int] = None,
207
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
208
+ pad_to_multiple_of: Optional[int] = None,
209
+ return_attention_mask: Optional[bool] = None,
210
+ ) -> dict:
211
+ """
212
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
213
+
214
+ Args:
215
+ encoded_inputs:
216
+ Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
217
+ max_length: maximum length of the returned list and optionally padding length (see below).
218
+ Will truncate by taking into account the special tokens.
219
+ padding_strategy: PaddingStrategy to use for padding.
220
+
221
+ - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
222
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
223
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
224
+ The tokenizer padding sides are defined in self.padding_side:
225
+
226
+ - 'left': pads on the left of the sequences
227
+ - 'right': pads on the right of the sequences
228
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
229
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
230
+ `>= 7.5` (Volta).
231
+ return_attention_mask:
232
+ (optional) Set to False to avoid returning attention mask (default: set to model specifics)
233
+ """
234
+ # Load from model defaults
235
+ assert self.padding_side == "left"
236
+
237
+ required_input = encoded_inputs[self.model_input_names[0]]
238
+ seq_length = len(required_input)
239
+
240
+ if padding_strategy == PaddingStrategy.LONGEST:
241
+ max_length = len(required_input)
242
+
243
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
244
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
245
+
246
+ needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
247
+
248
+ # Initialize attention mask if not present.
249
+ if "attention_mask" not in encoded_inputs:
250
+ encoded_inputs["attention_mask"] = [1] * seq_length
251
+
252
+ if "position_ids" not in encoded_inputs:
253
+ encoded_inputs["position_ids"] = list(range(seq_length))
254
+
255
+ if needs_to_be_padded:
256
+ difference = max_length - len(required_input)
257
+
258
+ if "attention_mask" in encoded_inputs:
259
+ encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
260
+ if "position_ids" in encoded_inputs:
261
+ encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
262
+ encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
263
+
264
+ return encoded_inputs
tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5a493598071550244b2ee7f26118f3edec2150b9dfa967929a99052ac83fe716
3
+ size 2623634
tokenizer_config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "151329": {
4
+ "content": "<|endoftext|>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ }
11
+ },
12
+ "auto_map": {
13
+ "AutoTokenizer": [
14
+ "tokenization_chatglm.ChatGLM4Tokenizer",
15
+ null
16
+ ]
17
+ },
18
+ "chat_template": "{% for message in messages %}{% if loop.first %}[gMASK]<sop><|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
19
+ "clean_up_tokenization_spaces": false,
20
+ "do_lower_case": false,
21
+ "eos_token": "<|endoftext|>",
22
+ "model_max_length": 1000000000000000019884624838656,
23
+ "padding_side": "left",
24
+ "remove_space": false,
25
+ "tokenizer_class": "ChatGLM4Tokenizer"
26
+ }
vllm_inference.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from vllm import LLM, SamplingParams
3
+ from nltk.tokenize import PunktSentenceTokenizer
4
+ import re
5
+ import torch
6
+
7
+ class LongCiteModel(LLM):
8
+
9
+ @torch.inference_mode()
10
+ def chat(self, tokenizer, query: str, history=None, role="user",
11
+ max_new_tokens=None, top_p=0.7, temperature=0.95):
12
+ if history is None:
13
+ history = []
14
+ inputs = tokenizer.build_chat_input(query, history=history, role=role)
15
+ eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"), tokenizer.get_command("<|observation|>")]
16
+ generation_params = SamplingParams(
17
+ temperature=temperature,
18
+ top_p=top_p,
19
+ max_tokens=max_new_tokens,
20
+ stop_token_ids=eos_token_id,
21
+ )
22
+ input_ids = inputs.input_ids[0].tolist()
23
+ outputs = self.generate(sampling_params=generation_params, prompt_token_ids=[input_ids])
24
+ response = tokenizer.decode(outputs[0].outputs[0].token_ids[:-1])
25
+ history.append({"role": role, "content": query})
26
+ return response, history
27
+
28
+ def query_longcite(self, context, query, tokenizer, max_input_length=128000, max_new_tokens=1024, temperature=0.95):
29
+
30
+ def text_split_by_punctuation(original_text, return_dict=False):
31
+ # text = re.sub(r'([a-z])\.([A-Z])', r'\1. \2', original_text) # separate period without space
32
+ text = original_text
33
+ custom_sent_tokenizer = PunktSentenceTokenizer(text)
34
+ punctuations = r"([。;!?])" # For Chinese support
35
+
36
+ separated = custom_sent_tokenizer.tokenize(text)
37
+ separated = sum([re.split(punctuations, s) for s in separated], [])
38
+ # Put the punctuations back to the sentence
39
+ for i in range(1, len(separated)):
40
+ if re.match(punctuations, separated[i]):
41
+ separated[i-1] += separated[i]
42
+ separated[i] = ''
43
+
44
+ separated = [s for s in separated if s != ""]
45
+ if len(separated) == 1:
46
+ separated = original_text.split('\n\n')
47
+ separated = [s.strip() for s in separated if s.strip() != ""]
48
+ if not return_dict:
49
+ return separated
50
+ else:
51
+ pos = 0
52
+ res = []
53
+ for i, sent in enumerate(separated):
54
+ st = original_text.find(sent, pos)
55
+ assert st != -1, sent
56
+ ed = st + len(sent)
57
+ res.append(
58
+ {
59
+ 'c_idx': i,
60
+ 'content': sent,
61
+ 'start_idx': st,
62
+ 'end_idx': ed,
63
+ }
64
+ )
65
+ pos = ed
66
+ return res
67
+
68
+ def get_prompt(context, question):
69
+ sents = text_split_by_punctuation(context, return_dict=True)
70
+ splited_context = ""
71
+ for i, s in enumerate(sents):
72
+ st, ed = s['start_idx'], s['end_idx']
73
+ assert s['content'] == context[st:ed], s
74
+ ed = sents[i+1]['start_idx'] if i < len(sents)-1 else len(context)
75
+ sents[i] = {
76
+ 'content': context[st:ed],
77
+ 'start': st,
78
+ 'end': ed,
79
+ 'c_idx': s['c_idx'],
80
+ }
81
+ splited_context += f"<C{i}>"+context[st:ed]
82
+ prompt = '''Please answer the user's question based on the following document. When a sentence S in your response uses information from some chunks in the document (i.e., <C{s1}>-<C_{e1}>, <C{s2}>-<C{e2}>, ...), please append these chunk numbers to S in the format "<statement>{S}<cite>[{s1}-{e1}][{s2}-{e2}]...</cite></statement>". You must answer in the same language as the user's question.\n\n[Document Start]\n%s\n[Document End]\n\n%s''' % (splited_context, question)
83
+ return prompt, sents, splited_context
84
+
85
+ def get_citations(statement, sents):
86
+ c_texts = re.findall(r'<cite>(.*?)</cite>', statement, re.DOTALL)
87
+ spans = sum([re.findall(r"\[([0-9]+\-[0-9]+)\]", c_text, re.DOTALL) for c_text in c_texts], [])
88
+ statement = re.sub(r'<cite>(.*?)</cite>', '', statement, flags=re.DOTALL)
89
+ merged_citations = []
90
+ for i, s in enumerate(spans):
91
+ try:
92
+ st, ed = [int(x) for x in s.split('-')]
93
+ if st > len(sents) - 1 or ed < st:
94
+ continue
95
+ st, ed = max(0, st), min(ed, len(sents)-1)
96
+ assert st <= ed, str(c_texts) + '\t' + str(len(sents))
97
+ if len(merged_citations) > 0 and st == merged_citations[-1]['end_sentence_idx'] + 1:
98
+ merged_citations[-1].update({
99
+ "end_sentence_idx": ed,
100
+ 'end_char_idx': sents[ed]['end'],
101
+ 'cite': ''.join([x['content'] for x in sents[merged_citations[-1]['start_sentence_idx']:ed+1]]),
102
+ })
103
+ else:
104
+ merged_citations.append({
105
+ "start_sentence_idx": st,
106
+ "end_sentence_idx": ed,
107
+ "start_char_idx": sents[st]['start'],
108
+ 'end_char_idx': sents[ed]['end'],
109
+ 'cite': ''.join([x['content'] for x in sents[st:ed+1]]),
110
+ })
111
+ except:
112
+ print(c_texts, len(sents), statement)
113
+ raise
114
+ return statement, merged_citations[:3]
115
+
116
+ def postprocess(answer, sents, splited_context):
117
+ res = []
118
+ pos = 0
119
+ new_answer = ""
120
+ while True:
121
+ st = answer.find("<statement>", pos)
122
+ if st == -1:
123
+ st = len(answer)
124
+ ed = answer.find("</statement>", st)
125
+ statement = answer[pos:st]
126
+ if len(statement.strip()) > 5:
127
+ res.append({
128
+ "statement": statement,
129
+ "citation": []
130
+ })
131
+ new_answer += f"<statement>{statement}<cite></cite></statement>"
132
+ else:
133
+ res.append({
134
+ "statement": statement,
135
+ "citation": None,
136
+ })
137
+ new_answer += statement
138
+
139
+ if ed == -1:
140
+ break
141
+
142
+ statement = answer[st+len("<statement>"):ed]
143
+ if len(statement.strip()) > 0:
144
+ statement, citations = get_citations(statement, sents)
145
+ res.append({
146
+ "statement": statement,
147
+ "citation": citations
148
+ })
149
+ c_str = ''.join(['[{}-{}]'.format(c['start_sentence_idx'], c['end_sentence_idx']) for c in citations])
150
+ new_answer += f"<statement>{statement}<cite>{c_str}</cite></statement>"
151
+ else:
152
+ res.append({
153
+ "statement": statement,
154
+ "citation": None,
155
+ })
156
+ new_answer += statement
157
+ pos = ed + len("</statement>")
158
+ return {
159
+ "answer": new_answer.strip(),
160
+ "statements_with_citations": [x for x in res if x['citation'] is not None],
161
+ "splited_context": splited_context.strip(),
162
+ "all_statements": res,
163
+ }
164
+
165
+ def truncate_from_middle(prompt, max_input_length=None, tokenizer=None):
166
+ if max_input_length is None:
167
+ return prompt
168
+ else:
169
+ assert tokenizer is not None
170
+ tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
171
+ if len(tokenized_prompt) > max_input_length:
172
+ half = int(max_input_length/2)
173
+ prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)+tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
174
+ return prompt
175
+
176
+ prompt, sents, splited_context = get_prompt(context, query)
177
+ prompt = truncate_from_middle(prompt, max_input_length, tokenizer)
178
+ output, _ = self.chat(tokenizer, prompt, history=[], max_new_tokens=max_new_tokens, temperature=temperature)
179
+ result = postprocess(output, sents, splited_context)
180
+ return result
181
+
182
+
183
+ if __name__ == "__main__":
184
+ model_path = "THUDM/LongCite-glm4-9b"
185
+ model = LongCiteModel(
186
+ model= model_path,
187
+ dtype=torch.bfloat16,
188
+ trust_remote_code=True,
189
+ tensor_parallel_size=1,
190
+ max_model_len=131072,
191
+ gpu_memory_utilization=1,
192
+ )
193
+ tokenizer = model.get_tokenizer()
194
+
195
+ context = '''
196
+ W. Russell Todd, 94, United States Army general (b. 1928). February 13. Tim Aymar, 59, heavy metal singer (Pharaoh) (b. 1963). Marshall \"Eddie\" Conway, 76, Black Panther Party leader (b. 1946). Roger Bonk, 78, football player (North Dakota Fighting Sioux, Winnipeg Blue Bombers) (b. 1944). Conrad Dobler, 72, football player (St. Louis Cardinals, New Orleans Saints, Buffalo Bills) (b. 1950). Brian DuBois, 55, baseball player (Detroit Tigers) (b. 1967). Robert Geddes, 99, architect, dean of the Princeton University School of Architecture (1965–1982) (b. 1923). Tom Luddy, 79, film producer (Barfly, The Secret Garden), co-founder of the Telluride Film Festival (b. 1943). David Singmaster, 84, mathematician (b. 1938).
197
+ '''
198
+ query = "What was Robert Geddes' profession?"
199
+ result = model.query_longcite(context, query, tokenizer=tokenizer, max_input_length=128000, max_new_tokens=1024)
200
+
201
+ print("Answer:")
202
+ print(result['answer'])
203
+ print('\n')
204
+ print("Statement with citations:" )
205
+ print(json.dumps(result['statements_with_citations'], indent=2, ensure_ascii=False))
206
+ print('\n')
207
+ print("Context (divided into sentences):")
208
+ print(result['splited_context'])