huskyhong commited on
Commit
5c14c40
·
1 Parent(s): 0aa538b

Upload 18 files

Browse files
config.json ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/root/workspace/models/Qwen-7B-Chat",
3
+ "activation": "swiglu",
4
+ "apply_residual_connection_post_layernorm": false,
5
+ "architectures": [
6
+ "QWenLMHeadModel"
7
+ ],
8
+ "attn_pdrop": 0.0,
9
+ "auto_map": {
10
+ "AutoConfig": "configuration_qwen.QWenConfig",
11
+ "AutoModelForCausalLM": "modeling_qwen.QWenLMHeadModel"
12
+ },
13
+ "bf16": true,
14
+ "bias_dropout_fusion": true,
15
+ "bos_token_id": 151643,
16
+ "embd_pdrop": 0.0,
17
+ "eos_token_id": 151643,
18
+ "ffn_hidden_size": 22016,
19
+ "fp16": false,
20
+ "fp32": false,
21
+ "initializer_range": 0.02,
22
+ "kv_channels": 128,
23
+ "layer_norm_epsilon": 1e-06,
24
+ "model_type": "qwen",
25
+ "n_embd": 4096,
26
+ "n_head": 32,
27
+ "n_inner": null,
28
+ "n_layer": 32,
29
+ "n_positions": 6144,
30
+ "no_bias": true,
31
+ "onnx_safe": null,
32
+ "padded_vocab_size": 151936,
33
+ "params_dtype": "torch.bfloat16",
34
+ "pos_emb": "rotary",
35
+ "resid_pdrop": 0.1,
36
+ "rotary_emb_base": 10000,
37
+ "rotary_pct": 1.0,
38
+ "scale_attn_weights": true,
39
+ "seq_length": 2048,
40
+ "tie_word_embeddings": false,
41
+ "tokenizer_type": "QWenTokenizer",
42
+ "torch_dtype": "bfloat16",
43
+ "transformers_version": "4.34.0",
44
+ "use_cache": true,
45
+ "use_dynamic_ntk": true,
46
+ "use_flash_attn": true,
47
+ "use_logn_attn": true,
48
+ "vocab_size": 151936
49
+ }
configuration_qwen.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from transformers import PretrainedConfig
7
+
8
+
9
+ class QWenConfig(PretrainedConfig):
10
+ model_type = "qwen"
11
+ keys_to_ignore_at_inference = ["past_key_values"]
12
+ attribute_map = {
13
+ "hidden_size": "n_embd",
14
+ "num_attention_heads": "n_head",
15
+ "max_position_embeddings": "n_positions",
16
+ "num_hidden_layers": "n_layer",
17
+ }
18
+
19
+ def __init__(
20
+ self,
21
+ vocab_size=151851,
22
+ n_embd=4096,
23
+ n_layer=32,
24
+ n_head=32,
25
+ n_inner=None,
26
+ embd_pdrop=0.0,
27
+ attn_pdrop=0.0,
28
+ layer_norm_epsilon=1e-5,
29
+ initializer_range=0.02,
30
+ scale_attn_weights=True,
31
+ use_cache=True,
32
+ eos_token_id=151643,
33
+ apply_residual_connection_post_layernorm=False,
34
+ bf16=False,
35
+ fp16=False,
36
+ fp32=False,
37
+ kv_channels=128,
38
+ rotary_pct=1.0,
39
+ rotary_emb_base=10000,
40
+ use_dynamic_ntk=False,
41
+ use_logn_attn=False,
42
+ use_flash_attn=True,
43
+ ffn_hidden_size=22016,
44
+ no_bias=True,
45
+ tie_word_embeddings=False,
46
+ **kwargs,
47
+ ):
48
+ self.eos_token_id = eos_token_id
49
+ super().__init__(
50
+ eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
51
+ )
52
+
53
+ self.vocab_size = vocab_size
54
+ self.n_embd = n_embd
55
+ self.n_layer = n_layer
56
+ self.n_head = n_head
57
+ self.n_inner = n_inner
58
+ self.embd_pdrop = embd_pdrop
59
+ self.attn_pdrop = attn_pdrop
60
+ self.layer_norm_epsilon = layer_norm_epsilon
61
+ self.initializer_range = initializer_range
62
+ self.scale_attn_weights = scale_attn_weights
63
+ self.use_cache = use_cache
64
+ self.apply_residual_connection_post_layernorm = (
65
+ apply_residual_connection_post_layernorm
66
+ )
67
+ self.bf16 = bf16
68
+ self.fp16 = fp16
69
+ self.fp32 = fp32
70
+ self.kv_channels = kv_channels
71
+ self.rotary_pct = rotary_pct
72
+ self.rotary_emb_base = rotary_emb_base
73
+ self.use_dynamic_ntk = use_dynamic_ntk
74
+ self.use_logn_attn = use_logn_attn
75
+ self.use_flash_attn = use_flash_attn
76
+ self.ffn_hidden_size = ffn_hidden_size
77
+ self.no_bias = no_bias
78
+ self.tie_word_embeddings = tie_word_embeddings
generation_config.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "chat_format": "chatml",
3
+ "decay_bound": 0.0,
4
+ "decay_factor": 1.0,
5
+ "do_sample": true,
6
+ "eos_token_id": 151643,
7
+ "factual_nucleus_sampling": false,
8
+ "max_context_size": 1024,
9
+ "max_generate_size": 512,
10
+ "max_new_tokens": 512,
11
+ "pad_token_id": 151643,
12
+ "stop_words_ids": [
13
+ [
14
+ 151643
15
+ ]
16
+ ],
17
+ "top_k": 0,
18
+ "top_p": 0.8,
19
+ "transformers_version": "4.34.0"
20
+ }
model-00001-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d6eacaf5b398a941f4a885dcfdefacd46d58f47df3df7afa161f7bc7a44d60b0
3
+ size 1964066488
model-00002-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3e6cf727ee6f28697cd479d153a47f4601d3a4dbbe34bc75ff69cb8e493fc767
3
+ size 2023960808
model-00003-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:05e713f1bfd491c5c88d492fc9fdc174b8308cd57d76a109dec856aae04f92aa
3
+ size 2023960816
model-00004-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8642f192a2ddbc54ca5d6252897ec278d76726cf8af5f3f74aac0251675f020
3
+ size 2023960848
model-00005-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:299ab045af98ff22818253b537c96ab8e5c2802517bbc983e0a8abe704cbaed6
3
+ size 2023960848
model-00006-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:96e76e710497bd1aa78a52a6f294c6d5e2a71e6a524414ab1d2827ec2e032423
3
+ size 2023960848
model-00007-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1accec9f9040595afe4cc59b65581001c3f35818c695f6f6f3140846287637ee
3
+ size 2023960848
model-00008-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a5f6ce9190c8ad5b56b5fd5e5478e5294a8e413a47a5d3118197d97dbc9b2cfb
3
+ size 1334845784
model.safetensors.index.json ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 15442649088
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "model-00008-of-00008.safetensors",
7
+ "transformer.h.0.attn.c_attn.bias": "model-00001-of-00008.safetensors",
8
+ "transformer.h.0.attn.c_attn.weight": "model-00001-of-00008.safetensors",
9
+ "transformer.h.0.attn.c_proj.weight": "model-00001-of-00008.safetensors",
10
+ "transformer.h.0.ln_1.weight": "model-00001-of-00008.safetensors",
11
+ "transformer.h.0.ln_2.weight": "model-00001-of-00008.safetensors",
12
+ "transformer.h.0.mlp.c_proj.weight": "model-00001-of-00008.safetensors",
13
+ "transformer.h.0.mlp.w1.weight": "model-00001-of-00008.safetensors",
14
+ "transformer.h.0.mlp.w2.weight": "model-00001-of-00008.safetensors",
15
+ "transformer.h.1.attn.c_attn.bias": "model-00001-of-00008.safetensors",
16
+ "transformer.h.1.attn.c_attn.weight": "model-00001-of-00008.safetensors",
17
+ "transformer.h.1.attn.c_proj.weight": "model-00001-of-00008.safetensors",
18
+ "transformer.h.1.ln_1.weight": "model-00001-of-00008.safetensors",
19
+ "transformer.h.1.ln_2.weight": "model-00001-of-00008.safetensors",
20
+ "transformer.h.1.mlp.c_proj.weight": "model-00002-of-00008.safetensors",
21
+ "transformer.h.1.mlp.w1.weight": "model-00001-of-00008.safetensors",
22
+ "transformer.h.1.mlp.w2.weight": "model-00001-of-00008.safetensors",
23
+ "transformer.h.10.attn.c_attn.bias": "model-00003-of-00008.safetensors",
24
+ "transformer.h.10.attn.c_attn.weight": "model-00003-of-00008.safetensors",
25
+ "transformer.h.10.attn.c_proj.weight": "model-00003-of-00008.safetensors",
26
+ "transformer.h.10.ln_1.weight": "model-00003-of-00008.safetensors",
27
+ "transformer.h.10.ln_2.weight": "model-00003-of-00008.safetensors",
28
+ "transformer.h.10.mlp.c_proj.weight": "model-00003-of-00008.safetensors",
29
+ "transformer.h.10.mlp.w1.weight": "model-00003-of-00008.safetensors",
30
+ "transformer.h.10.mlp.w2.weight": "model-00003-of-00008.safetensors",
31
+ "transformer.h.11.attn.c_attn.bias": "model-00003-of-00008.safetensors",
32
+ "transformer.h.11.attn.c_attn.weight": "model-00003-of-00008.safetensors",
33
+ "transformer.h.11.attn.c_proj.weight": "model-00003-of-00008.safetensors",
34
+ "transformer.h.11.ln_1.weight": "model-00003-of-00008.safetensors",
35
+ "transformer.h.11.ln_2.weight": "model-00003-of-00008.safetensors",
36
+ "transformer.h.11.mlp.c_proj.weight": "model-00004-of-00008.safetensors",
37
+ "transformer.h.11.mlp.w1.weight": "model-00003-of-00008.safetensors",
38
+ "transformer.h.11.mlp.w2.weight": "model-00003-of-00008.safetensors",
39
+ "transformer.h.12.attn.c_attn.bias": "model-00004-of-00008.safetensors",
40
+ "transformer.h.12.attn.c_attn.weight": "model-00004-of-00008.safetensors",
41
+ "transformer.h.12.attn.c_proj.weight": "model-00004-of-00008.safetensors",
42
+ "transformer.h.12.ln_1.weight": "model-00004-of-00008.safetensors",
43
+ "transformer.h.12.ln_2.weight": "model-00004-of-00008.safetensors",
44
+ "transformer.h.12.mlp.c_proj.weight": "model-00004-of-00008.safetensors",
45
+ "transformer.h.12.mlp.w1.weight": "model-00004-of-00008.safetensors",
46
+ "transformer.h.12.mlp.w2.weight": "model-00004-of-00008.safetensors",
47
+ "transformer.h.13.attn.c_attn.bias": "model-00004-of-00008.safetensors",
48
+ "transformer.h.13.attn.c_attn.weight": "model-00004-of-00008.safetensors",
49
+ "transformer.h.13.attn.c_proj.weight": "model-00004-of-00008.safetensors",
50
+ "transformer.h.13.ln_1.weight": "model-00004-of-00008.safetensors",
51
+ "transformer.h.13.ln_2.weight": "model-00004-of-00008.safetensors",
52
+ "transformer.h.13.mlp.c_proj.weight": "model-00004-of-00008.safetensors",
53
+ "transformer.h.13.mlp.w1.weight": "model-00004-of-00008.safetensors",
54
+ "transformer.h.13.mlp.w2.weight": "model-00004-of-00008.safetensors",
55
+ "transformer.h.14.attn.c_attn.bias": "model-00004-of-00008.safetensors",
56
+ "transformer.h.14.attn.c_attn.weight": "model-00004-of-00008.safetensors",
57
+ "transformer.h.14.attn.c_proj.weight": "model-00004-of-00008.safetensors",
58
+ "transformer.h.14.ln_1.weight": "model-00004-of-00008.safetensors",
59
+ "transformer.h.14.ln_2.weight": "model-00004-of-00008.safetensors",
60
+ "transformer.h.14.mlp.c_proj.weight": "model-00004-of-00008.safetensors",
61
+ "transformer.h.14.mlp.w1.weight": "model-00004-of-00008.safetensors",
62
+ "transformer.h.14.mlp.w2.weight": "model-00004-of-00008.safetensors",
63
+ "transformer.h.15.attn.c_attn.bias": "model-00004-of-00008.safetensors",
64
+ "transformer.h.15.attn.c_attn.weight": "model-00004-of-00008.safetensors",
65
+ "transformer.h.15.attn.c_proj.weight": "model-00004-of-00008.safetensors",
66
+ "transformer.h.15.ln_1.weight": "model-00004-of-00008.safetensors",
67
+ "transformer.h.15.ln_2.weight": "model-00004-of-00008.safetensors",
68
+ "transformer.h.15.mlp.c_proj.weight": "model-00004-of-00008.safetensors",
69
+ "transformer.h.15.mlp.w1.weight": "model-00004-of-00008.safetensors",
70
+ "transformer.h.15.mlp.w2.weight": "model-00004-of-00008.safetensors",
71
+ "transformer.h.16.attn.c_attn.bias": "model-00004-of-00008.safetensors",
72
+ "transformer.h.16.attn.c_attn.weight": "model-00004-of-00008.safetensors",
73
+ "transformer.h.16.attn.c_proj.weight": "model-00004-of-00008.safetensors",
74
+ "transformer.h.16.ln_1.weight": "model-00004-of-00008.safetensors",
75
+ "transformer.h.16.ln_2.weight": "model-00004-of-00008.safetensors",
76
+ "transformer.h.16.mlp.c_proj.weight": "model-00005-of-00008.safetensors",
77
+ "transformer.h.16.mlp.w1.weight": "model-00004-of-00008.safetensors",
78
+ "transformer.h.16.mlp.w2.weight": "model-00004-of-00008.safetensors",
79
+ "transformer.h.17.attn.c_attn.bias": "model-00005-of-00008.safetensors",
80
+ "transformer.h.17.attn.c_attn.weight": "model-00005-of-00008.safetensors",
81
+ "transformer.h.17.attn.c_proj.weight": "model-00005-of-00008.safetensors",
82
+ "transformer.h.17.ln_1.weight": "model-00005-of-00008.safetensors",
83
+ "transformer.h.17.ln_2.weight": "model-00005-of-00008.safetensors",
84
+ "transformer.h.17.mlp.c_proj.weight": "model-00005-of-00008.safetensors",
85
+ "transformer.h.17.mlp.w1.weight": "model-00005-of-00008.safetensors",
86
+ "transformer.h.17.mlp.w2.weight": "model-00005-of-00008.safetensors",
87
+ "transformer.h.18.attn.c_attn.bias": "model-00005-of-00008.safetensors",
88
+ "transformer.h.18.attn.c_attn.weight": "model-00005-of-00008.safetensors",
89
+ "transformer.h.18.attn.c_proj.weight": "model-00005-of-00008.safetensors",
90
+ "transformer.h.18.ln_1.weight": "model-00005-of-00008.safetensors",
91
+ "transformer.h.18.ln_2.weight": "model-00005-of-00008.safetensors",
92
+ "transformer.h.18.mlp.c_proj.weight": "model-00005-of-00008.safetensors",
93
+ "transformer.h.18.mlp.w1.weight": "model-00005-of-00008.safetensors",
94
+ "transformer.h.18.mlp.w2.weight": "model-00005-of-00008.safetensors",
95
+ "transformer.h.19.attn.c_attn.bias": "model-00005-of-00008.safetensors",
96
+ "transformer.h.19.attn.c_attn.weight": "model-00005-of-00008.safetensors",
97
+ "transformer.h.19.attn.c_proj.weight": "model-00005-of-00008.safetensors",
98
+ "transformer.h.19.ln_1.weight": "model-00005-of-00008.safetensors",
99
+ "transformer.h.19.ln_2.weight": "model-00005-of-00008.safetensors",
100
+ "transformer.h.19.mlp.c_proj.weight": "model-00005-of-00008.safetensors",
101
+ "transformer.h.19.mlp.w1.weight": "model-00005-of-00008.safetensors",
102
+ "transformer.h.19.mlp.w2.weight": "model-00005-of-00008.safetensors",
103
+ "transformer.h.2.attn.c_attn.bias": "model-00002-of-00008.safetensors",
104
+ "transformer.h.2.attn.c_attn.weight": "model-00002-of-00008.safetensors",
105
+ "transformer.h.2.attn.c_proj.weight": "model-00002-of-00008.safetensors",
106
+ "transformer.h.2.ln_1.weight": "model-00002-of-00008.safetensors",
107
+ "transformer.h.2.ln_2.weight": "model-00002-of-00008.safetensors",
108
+ "transformer.h.2.mlp.c_proj.weight": "model-00002-of-00008.safetensors",
109
+ "transformer.h.2.mlp.w1.weight": "model-00002-of-00008.safetensors",
110
+ "transformer.h.2.mlp.w2.weight": "model-00002-of-00008.safetensors",
111
+ "transformer.h.20.attn.c_attn.bias": "model-00005-of-00008.safetensors",
112
+ "transformer.h.20.attn.c_attn.weight": "model-00005-of-00008.safetensors",
113
+ "transformer.h.20.attn.c_proj.weight": "model-00005-of-00008.safetensors",
114
+ "transformer.h.20.ln_1.weight": "model-00005-of-00008.safetensors",
115
+ "transformer.h.20.ln_2.weight": "model-00005-of-00008.safetensors",
116
+ "transformer.h.20.mlp.c_proj.weight": "model-00005-of-00008.safetensors",
117
+ "transformer.h.20.mlp.w1.weight": "model-00005-of-00008.safetensors",
118
+ "transformer.h.20.mlp.w2.weight": "model-00005-of-00008.safetensors",
119
+ "transformer.h.21.attn.c_attn.bias": "model-00005-of-00008.safetensors",
120
+ "transformer.h.21.attn.c_attn.weight": "model-00005-of-00008.safetensors",
121
+ "transformer.h.21.attn.c_proj.weight": "model-00005-of-00008.safetensors",
122
+ "transformer.h.21.ln_1.weight": "model-00005-of-00008.safetensors",
123
+ "transformer.h.21.ln_2.weight": "model-00005-of-00008.safetensors",
124
+ "transformer.h.21.mlp.c_proj.weight": "model-00006-of-00008.safetensors",
125
+ "transformer.h.21.mlp.w1.weight": "model-00005-of-00008.safetensors",
126
+ "transformer.h.21.mlp.w2.weight": "model-00005-of-00008.safetensors",
127
+ "transformer.h.22.attn.c_attn.bias": "model-00006-of-00008.safetensors",
128
+ "transformer.h.22.attn.c_attn.weight": "model-00006-of-00008.safetensors",
129
+ "transformer.h.22.attn.c_proj.weight": "model-00006-of-00008.safetensors",
130
+ "transformer.h.22.ln_1.weight": "model-00006-of-00008.safetensors",
131
+ "transformer.h.22.ln_2.weight": "model-00006-of-00008.safetensors",
132
+ "transformer.h.22.mlp.c_proj.weight": "model-00006-of-00008.safetensors",
133
+ "transformer.h.22.mlp.w1.weight": "model-00006-of-00008.safetensors",
134
+ "transformer.h.22.mlp.w2.weight": "model-00006-of-00008.safetensors",
135
+ "transformer.h.23.attn.c_attn.bias": "model-00006-of-00008.safetensors",
136
+ "transformer.h.23.attn.c_attn.weight": "model-00006-of-00008.safetensors",
137
+ "transformer.h.23.attn.c_proj.weight": "model-00006-of-00008.safetensors",
138
+ "transformer.h.23.ln_1.weight": "model-00006-of-00008.safetensors",
139
+ "transformer.h.23.ln_2.weight": "model-00006-of-00008.safetensors",
140
+ "transformer.h.23.mlp.c_proj.weight": "model-00006-of-00008.safetensors",
141
+ "transformer.h.23.mlp.w1.weight": "model-00006-of-00008.safetensors",
142
+ "transformer.h.23.mlp.w2.weight": "model-00006-of-00008.safetensors",
143
+ "transformer.h.24.attn.c_attn.bias": "model-00006-of-00008.safetensors",
144
+ "transformer.h.24.attn.c_attn.weight": "model-00006-of-00008.safetensors",
145
+ "transformer.h.24.attn.c_proj.weight": "model-00006-of-00008.safetensors",
146
+ "transformer.h.24.ln_1.weight": "model-00006-of-00008.safetensors",
147
+ "transformer.h.24.ln_2.weight": "model-00006-of-00008.safetensors",
148
+ "transformer.h.24.mlp.c_proj.weight": "model-00006-of-00008.safetensors",
149
+ "transformer.h.24.mlp.w1.weight": "model-00006-of-00008.safetensors",
150
+ "transformer.h.24.mlp.w2.weight": "model-00006-of-00008.safetensors",
151
+ "transformer.h.25.attn.c_attn.bias": "model-00006-of-00008.safetensors",
152
+ "transformer.h.25.attn.c_attn.weight": "model-00006-of-00008.safetensors",
153
+ "transformer.h.25.attn.c_proj.weight": "model-00006-of-00008.safetensors",
154
+ "transformer.h.25.ln_1.weight": "model-00006-of-00008.safetensors",
155
+ "transformer.h.25.ln_2.weight": "model-00006-of-00008.safetensors",
156
+ "transformer.h.25.mlp.c_proj.weight": "model-00006-of-00008.safetensors",
157
+ "transformer.h.25.mlp.w1.weight": "model-00006-of-00008.safetensors",
158
+ "transformer.h.25.mlp.w2.weight": "model-00006-of-00008.safetensors",
159
+ "transformer.h.26.attn.c_attn.bias": "model-00006-of-00008.safetensors",
160
+ "transformer.h.26.attn.c_attn.weight": "model-00006-of-00008.safetensors",
161
+ "transformer.h.26.attn.c_proj.weight": "model-00006-of-00008.safetensors",
162
+ "transformer.h.26.ln_1.weight": "model-00006-of-00008.safetensors",
163
+ "transformer.h.26.ln_2.weight": "model-00006-of-00008.safetensors",
164
+ "transformer.h.26.mlp.c_proj.weight": "model-00007-of-00008.safetensors",
165
+ "transformer.h.26.mlp.w1.weight": "model-00006-of-00008.safetensors",
166
+ "transformer.h.26.mlp.w2.weight": "model-00006-of-00008.safetensors",
167
+ "transformer.h.27.attn.c_attn.bias": "model-00007-of-00008.safetensors",
168
+ "transformer.h.27.attn.c_attn.weight": "model-00007-of-00008.safetensors",
169
+ "transformer.h.27.attn.c_proj.weight": "model-00007-of-00008.safetensors",
170
+ "transformer.h.27.ln_1.weight": "model-00007-of-00008.safetensors",
171
+ "transformer.h.27.ln_2.weight": "model-00007-of-00008.safetensors",
172
+ "transformer.h.27.mlp.c_proj.weight": "model-00007-of-00008.safetensors",
173
+ "transformer.h.27.mlp.w1.weight": "model-00007-of-00008.safetensors",
174
+ "transformer.h.27.mlp.w2.weight": "model-00007-of-00008.safetensors",
175
+ "transformer.h.28.attn.c_attn.bias": "model-00007-of-00008.safetensors",
176
+ "transformer.h.28.attn.c_attn.weight": "model-00007-of-00008.safetensors",
177
+ "transformer.h.28.attn.c_proj.weight": "model-00007-of-00008.safetensors",
178
+ "transformer.h.28.ln_1.weight": "model-00007-of-00008.safetensors",
179
+ "transformer.h.28.ln_2.weight": "model-00007-of-00008.safetensors",
180
+ "transformer.h.28.mlp.c_proj.weight": "model-00007-of-00008.safetensors",
181
+ "transformer.h.28.mlp.w1.weight": "model-00007-of-00008.safetensors",
182
+ "transformer.h.28.mlp.w2.weight": "model-00007-of-00008.safetensors",
183
+ "transformer.h.29.attn.c_attn.bias": "model-00007-of-00008.safetensors",
184
+ "transformer.h.29.attn.c_attn.weight": "model-00007-of-00008.safetensors",
185
+ "transformer.h.29.attn.c_proj.weight": "model-00007-of-00008.safetensors",
186
+ "transformer.h.29.ln_1.weight": "model-00007-of-00008.safetensors",
187
+ "transformer.h.29.ln_2.weight": "model-00007-of-00008.safetensors",
188
+ "transformer.h.29.mlp.c_proj.weight": "model-00007-of-00008.safetensors",
189
+ "transformer.h.29.mlp.w1.weight": "model-00007-of-00008.safetensors",
190
+ "transformer.h.29.mlp.w2.weight": "model-00007-of-00008.safetensors",
191
+ "transformer.h.3.attn.c_attn.bias": "model-00002-of-00008.safetensors",
192
+ "transformer.h.3.attn.c_attn.weight": "model-00002-of-00008.safetensors",
193
+ "transformer.h.3.attn.c_proj.weight": "model-00002-of-00008.safetensors",
194
+ "transformer.h.3.ln_1.weight": "model-00002-of-00008.safetensors",
195
+ "transformer.h.3.ln_2.weight": "model-00002-of-00008.safetensors",
196
+ "transformer.h.3.mlp.c_proj.weight": "model-00002-of-00008.safetensors",
197
+ "transformer.h.3.mlp.w1.weight": "model-00002-of-00008.safetensors",
198
+ "transformer.h.3.mlp.w2.weight": "model-00002-of-00008.safetensors",
199
+ "transformer.h.30.attn.c_attn.bias": "model-00007-of-00008.safetensors",
200
+ "transformer.h.30.attn.c_attn.weight": "model-00007-of-00008.safetensors",
201
+ "transformer.h.30.attn.c_proj.weight": "model-00007-of-00008.safetensors",
202
+ "transformer.h.30.ln_1.weight": "model-00007-of-00008.safetensors",
203
+ "transformer.h.30.ln_2.weight": "model-00007-of-00008.safetensors",
204
+ "transformer.h.30.mlp.c_proj.weight": "model-00007-of-00008.safetensors",
205
+ "transformer.h.30.mlp.w1.weight": "model-00007-of-00008.safetensors",
206
+ "transformer.h.30.mlp.w2.weight": "model-00007-of-00008.safetensors",
207
+ "transformer.h.31.attn.c_attn.bias": "model-00007-of-00008.safetensors",
208
+ "transformer.h.31.attn.c_attn.weight": "model-00007-of-00008.safetensors",
209
+ "transformer.h.31.attn.c_proj.weight": "model-00007-of-00008.safetensors",
210
+ "transformer.h.31.ln_1.weight": "model-00007-of-00008.safetensors",
211
+ "transformer.h.31.ln_2.weight": "model-00007-of-00008.safetensors",
212
+ "transformer.h.31.mlp.c_proj.weight": "model-00008-of-00008.safetensors",
213
+ "transformer.h.31.mlp.w1.weight": "model-00007-of-00008.safetensors",
214
+ "transformer.h.31.mlp.w2.weight": "model-00007-of-00008.safetensors",
215
+ "transformer.h.4.attn.c_attn.bias": "model-00002-of-00008.safetensors",
216
+ "transformer.h.4.attn.c_attn.weight": "model-00002-of-00008.safetensors",
217
+ "transformer.h.4.attn.c_proj.weight": "model-00002-of-00008.safetensors",
218
+ "transformer.h.4.ln_1.weight": "model-00002-of-00008.safetensors",
219
+ "transformer.h.4.ln_2.weight": "model-00002-of-00008.safetensors",
220
+ "transformer.h.4.mlp.c_proj.weight": "model-00002-of-00008.safetensors",
221
+ "transformer.h.4.mlp.w1.weight": "model-00002-of-00008.safetensors",
222
+ "transformer.h.4.mlp.w2.weight": "model-00002-of-00008.safetensors",
223
+ "transformer.h.5.attn.c_attn.bias": "model-00002-of-00008.safetensors",
224
+ "transformer.h.5.attn.c_attn.weight": "model-00002-of-00008.safetensors",
225
+ "transformer.h.5.attn.c_proj.weight": "model-00002-of-00008.safetensors",
226
+ "transformer.h.5.ln_1.weight": "model-00002-of-00008.safetensors",
227
+ "transformer.h.5.ln_2.weight": "model-00002-of-00008.safetensors",
228
+ "transformer.h.5.mlp.c_proj.weight": "model-00002-of-00008.safetensors",
229
+ "transformer.h.5.mlp.w1.weight": "model-00002-of-00008.safetensors",
230
+ "transformer.h.5.mlp.w2.weight": "model-00002-of-00008.safetensors",
231
+ "transformer.h.6.attn.c_attn.bias": "model-00002-of-00008.safetensors",
232
+ "transformer.h.6.attn.c_attn.weight": "model-00002-of-00008.safetensors",
233
+ "transformer.h.6.attn.c_proj.weight": "model-00002-of-00008.safetensors",
234
+ "transformer.h.6.ln_1.weight": "model-00002-of-00008.safetensors",
235
+ "transformer.h.6.ln_2.weight": "model-00002-of-00008.safetensors",
236
+ "transformer.h.6.mlp.c_proj.weight": "model-00003-of-00008.safetensors",
237
+ "transformer.h.6.mlp.w1.weight": "model-00002-of-00008.safetensors",
238
+ "transformer.h.6.mlp.w2.weight": "model-00002-of-00008.safetensors",
239
+ "transformer.h.7.attn.c_attn.bias": "model-00003-of-00008.safetensors",
240
+ "transformer.h.7.attn.c_attn.weight": "model-00003-of-00008.safetensors",
241
+ "transformer.h.7.attn.c_proj.weight": "model-00003-of-00008.safetensors",
242
+ "transformer.h.7.ln_1.weight": "model-00003-of-00008.safetensors",
243
+ "transformer.h.7.ln_2.weight": "model-00003-of-00008.safetensors",
244
+ "transformer.h.7.mlp.c_proj.weight": "model-00003-of-00008.safetensors",
245
+ "transformer.h.7.mlp.w1.weight": "model-00003-of-00008.safetensors",
246
+ "transformer.h.7.mlp.w2.weight": "model-00003-of-00008.safetensors",
247
+ "transformer.h.8.attn.c_attn.bias": "model-00003-of-00008.safetensors",
248
+ "transformer.h.8.attn.c_attn.weight": "model-00003-of-00008.safetensors",
249
+ "transformer.h.8.attn.c_proj.weight": "model-00003-of-00008.safetensors",
250
+ "transformer.h.8.ln_1.weight": "model-00003-of-00008.safetensors",
251
+ "transformer.h.8.ln_2.weight": "model-00003-of-00008.safetensors",
252
+ "transformer.h.8.mlp.c_proj.weight": "model-00003-of-00008.safetensors",
253
+ "transformer.h.8.mlp.w1.weight": "model-00003-of-00008.safetensors",
254
+ "transformer.h.8.mlp.w2.weight": "model-00003-of-00008.safetensors",
255
+ "transformer.h.9.attn.c_attn.bias": "model-00003-of-00008.safetensors",
256
+ "transformer.h.9.attn.c_attn.weight": "model-00003-of-00008.safetensors",
257
+ "transformer.h.9.attn.c_proj.weight": "model-00003-of-00008.safetensors",
258
+ "transformer.h.9.ln_1.weight": "model-00003-of-00008.safetensors",
259
+ "transformer.h.9.ln_2.weight": "model-00003-of-00008.safetensors",
260
+ "transformer.h.9.mlp.c_proj.weight": "model-00003-of-00008.safetensors",
261
+ "transformer.h.9.mlp.w1.weight": "model-00003-of-00008.safetensors",
262
+ "transformer.h.9.mlp.w2.weight": "model-00003-of-00008.safetensors",
263
+ "transformer.ln_f.weight": "model-00008-of-00008.safetensors",
264
+ "transformer.wte.weight": "model-00001-of-00008.safetensors"
265
+ }
266
+ }
modeling_qwen.py ADDED
@@ -0,0 +1,1138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import importlib
7
+ import math
8
+ from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ import torch.utils.checkpoint
13
+ from torch.cuda.amp import autocast
14
+
15
+ from torch.nn import CrossEntropyLoss
16
+ from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
17
+ from transformers.generation.logits_process import LogitsProcessorList
18
+
19
+ if TYPE_CHECKING:
20
+ from transformers.generation.streamers import BaseStreamer
21
+ from transformers.generation.utils import GenerateOutput
22
+ from transformers.modeling_outputs import (
23
+ BaseModelOutputWithPast,
24
+ CausalLMOutputWithPast,
25
+ )
26
+ from transformers.modeling_utils import PreTrainedModel
27
+ from transformers.utils import logging
28
+
29
+ try:
30
+ from einops import rearrange
31
+ except ImportError:
32
+ rearrange = None
33
+ from torch import nn
34
+
35
+ SUPPORT_CUDA = torch.cuda.is_available()
36
+ SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
37
+ SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
38
+
39
+ apply_rotary_emb_func = None
40
+ rms_norm = None
41
+ flash_attn_unpadded_func = None
42
+
43
+ from .configuration_qwen import QWenConfig
44
+ from .qwen_generation_utils import (
45
+ HistoryType,
46
+ make_context,
47
+ decode_tokens,
48
+ get_stop_words_ids,
49
+ StopWordsLogitsProcessor,
50
+ )
51
+
52
+
53
+ logger = logging.get_logger(__name__)
54
+
55
+ _CHECKPOINT_FOR_DOC = "qwen"
56
+ _CONFIG_FOR_DOC = "QWenConfig"
57
+
58
+ QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"]
59
+
60
+ class FlashSelfAttention(torch.nn.Module):
61
+ def __init__(
62
+ self,
63
+ causal=False,
64
+ softmax_scale=None,
65
+ attention_dropout=0.0,
66
+ ):
67
+ super().__init__()
68
+ assert flash_attn_unpadded_func is not None, (
69
+ "Please install FlashAttention first, " "e.g., with pip install flash-attn"
70
+ )
71
+ assert (
72
+ rearrange is not None
73
+ ), "Please install einops first, e.g., with pip install einops"
74
+ self.causal = causal
75
+ self.softmax_scale = softmax_scale
76
+ self.dropout_p = attention_dropout
77
+
78
+ def forward(self, q, k, v):
79
+ assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
80
+ assert all((i.is_cuda for i in (q, k, v)))
81
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
82
+ seqlen_k = k.shape[1]
83
+ q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
84
+ cu_seqlens_q = torch.arange(
85
+ 0,
86
+ (batch_size + 1) * seqlen_q,
87
+ step=seqlen_q,
88
+ dtype=torch.int32,
89
+ device=q.device,
90
+ )
91
+
92
+ if self.training:
93
+ assert seqlen_k == seqlen_q
94
+
95
+ is_causal = self.causal
96
+ cu_seqlens_k = cu_seqlens_q
97
+ else:
98
+ is_causal = seqlen_q == seqlen_k
99
+ cu_seqlens_k = torch.arange(
100
+ 0,
101
+ (batch_size + 1) * seqlen_k,
102
+ step=seqlen_k,
103
+ dtype=torch.int32,
104
+ device=q.device,
105
+ )
106
+ self.dropout_p = 0
107
+ output = flash_attn_unpadded_func(
108
+ q,
109
+ k,
110
+ v,
111
+ cu_seqlens_q,
112
+ cu_seqlens_k,
113
+ seqlen_q,
114
+ seqlen_k,
115
+ self.dropout_p,
116
+ softmax_scale=self.softmax_scale,
117
+ causal=is_causal,
118
+ )
119
+
120
+ output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
121
+ return output
122
+
123
+
124
+ class QWenAttention(nn.Module):
125
+ def __init__(self, config, layer_number=None):
126
+ super().__init__()
127
+
128
+ max_positions = config.max_position_embeddings
129
+ self.register_buffer(
130
+ "bias",
131
+ torch.tril(
132
+ torch.ones((max_positions, max_positions), dtype=torch.bool)
133
+ ).view(1, 1, max_positions, max_positions),
134
+ persistent=False,
135
+ )
136
+ self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
137
+ self.layer_number = max(1, layer_number)
138
+ self.params_dtype = config.params_dtype
139
+ self.seq_length = config.seq_length
140
+
141
+ self.hidden_size = config.hidden_size
142
+ self.split_size = config.hidden_size
143
+ self.num_heads = config.num_attention_heads
144
+ self.head_dim = self.hidden_size // self.num_heads
145
+
146
+ self.use_flash_attn = config.use_flash_attn
147
+ self.scale_attn_weights = True
148
+
149
+ self.layer_idx = None
150
+
151
+ self.projection_size = config.kv_channels * config.num_attention_heads
152
+
153
+ assert self.projection_size % config.num_attention_heads == 0
154
+ self.hidden_size_per_attention_head = (
155
+ self.projection_size // config.num_attention_heads
156
+ )
157
+
158
+ self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size)
159
+
160
+ self.c_proj = nn.Linear(
161
+ config.hidden_size, self.projection_size, bias=not config.no_bias
162
+ )
163
+
164
+ self.is_fp32 = not (config.bf16 or config.fp16)
165
+ if (
166
+ self.use_flash_attn
167
+ and flash_attn_unpadded_func is not None
168
+ and not self.is_fp32
169
+ ):
170
+ self.core_attention_flash = FlashSelfAttention(
171
+ causal=True, attention_dropout=config.attn_pdrop
172
+ )
173
+
174
+ self.bf16 = config.bf16
175
+
176
+ if config.rotary_pct == 1.0:
177
+ self.rotary_ndims = None
178
+ else:
179
+ assert config.rotary_pct < 1
180
+ self.rotary_ndims = int(
181
+ self.hidden_size_per_attention_head * config.rotary_pct
182
+ )
183
+ dim = (
184
+ self.rotary_ndims
185
+ if self.rotary_ndims is not None
186
+ else self.hidden_size_per_attention_head
187
+ )
188
+ self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
189
+
190
+ self.use_dynamic_ntk = config.use_dynamic_ntk
191
+ self.use_logn_attn = config.use_logn_attn
192
+
193
+ logn_list = [
194
+ math.log(i, self.seq_length) if i > self.seq_length else 1
195
+ for i in range(1, 32768)
196
+ ]
197
+ self.logn_tensor = torch.Tensor(logn_list)[None, :, None, None]
198
+ self._ntk_cached = 1.0
199
+
200
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
201
+
202
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
203
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
204
+
205
+ if self.scale_attn_weights:
206
+ attn_weights = attn_weights / torch.full(
207
+ [],
208
+ value.size(-1) ** 0.5,
209
+ dtype=attn_weights.dtype,
210
+ device=attn_weights.device,
211
+ )
212
+
213
+ query_length, key_length = query.size(-2), key.size(-2)
214
+ causal_mask = self.bias[
215
+ :, :, key_length - query_length : key_length, :key_length
216
+ ]
217
+ mask_value = torch.finfo(attn_weights.dtype).min
218
+ mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(
219
+ attn_weights.device
220
+ )
221
+ attn_weights = torch.where(
222
+ causal_mask, attn_weights.to(attn_weights.dtype), mask_value
223
+ )
224
+
225
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
226
+
227
+ attn_weights = attn_weights.type(value.dtype)
228
+ attn_weights = self.attn_dropout(attn_weights)
229
+
230
+ if head_mask is not None:
231
+ attn_weights = attn_weights * head_mask
232
+
233
+ attn_output = torch.matmul(attn_weights, value)
234
+ attn_output = attn_output.transpose(1, 2)
235
+ # print("attn_weights:", attn_weights)
236
+
237
+ return attn_output, attn_weights
238
+
239
+ def _upcast_and_reordered_attn(
240
+ self, query, key, value, attention_mask=None, head_mask=None
241
+ ):
242
+ bsz, num_heads, q_seq_len, dk = query.size()
243
+ _, _, k_seq_len, _ = key.size()
244
+
245
+ attn_weights = torch.empty(
246
+ bsz * num_heads,
247
+ q_seq_len,
248
+ k_seq_len,
249
+ dtype=torch.float32,
250
+ device=query.device,
251
+ )
252
+
253
+ scale_factor = 1.0
254
+ if self.scale_attn_weights:
255
+ scale_factor /= float(value.size(-1)) ** 0.5
256
+
257
+ with autocast(enabled=False):
258
+ q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(
259
+ -1, dk, k_seq_len
260
+ )
261
+ attn_weights = torch.baddbmm(
262
+ attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor
263
+ )
264
+ attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
265
+
266
+ query_length, key_length = query.size(-2), key.size(-2)
267
+ causal_mask = self.bias[
268
+ :, :, key_length - query_length : key_length, :key_length
269
+ ]
270
+ mask_value = torch.finfo(attn_weights.dtype).min
271
+ mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(
272
+ attn_weights.device
273
+ )
274
+ attn_weights = torch.where(causal_mask, attn_weights, mask_value)
275
+
276
+ if attention_mask is not None:
277
+ attn_weights = attn_weights + attention_mask
278
+
279
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
280
+
281
+ if attn_weights.dtype != torch.float32:
282
+ raise RuntimeError(
283
+ "Error with upcasting, attn_weights does not have dtype torch.float32"
284
+ )
285
+ attn_weights = attn_weights.type(value.dtype)
286
+ attn_weights = self.attn_dropout(attn_weights)
287
+
288
+ if head_mask is not None:
289
+ attn_weights = attn_weights * head_mask
290
+
291
+ attn_output = torch.matmul(attn_weights, value)
292
+
293
+ return attn_output, attn_weights
294
+
295
+ def _split_heads(self, tensor, num_heads, attn_head_size):
296
+ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
297
+ tensor = tensor.view(new_shape)
298
+ return tensor
299
+
300
+ def _merge_heads(self, tensor, num_heads, attn_head_size):
301
+ tensor = tensor.contiguous()
302
+ new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
303
+ return tensor.view(new_shape)
304
+
305
+ def forward(
306
+ self,
307
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
308
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
309
+ attention_mask: Optional[torch.FloatTensor] = None,
310
+ head_mask: Optional[torch.FloatTensor] = None,
311
+ encoder_hidden_states: Optional[torch.Tensor] = None,
312
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
313
+ output_attentions: Optional[bool] = False,
314
+ use_cache: Optional[bool] = False,
315
+ ):
316
+
317
+ mixed_x_layer = self.c_attn(hidden_states)
318
+ query, key, value = mixed_x_layer.split(self.split_size, dim=2)
319
+
320
+ query = self._split_heads(query, self.num_heads, self.head_dim)
321
+ key = self._split_heads(key, self.num_heads, self.head_dim)
322
+ value = self._split_heads(value, self.num_heads, self.head_dim)
323
+
324
+ kv_seq_len = hidden_states.size()[1]
325
+ if layer_past:
326
+ # layer past[0] shape: bs * seq_len * head_num * dim
327
+ kv_seq_len += layer_past[0].shape[1]
328
+ if (
329
+ self.use_dynamic_ntk
330
+ and kv_seq_len == hidden_states.size()[1]
331
+ and not self.training
332
+ ):
333
+ context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
334
+ ntk_alpha = 2 ** math.ceil(context_value) - 1
335
+ ntk_alpha = max(ntk_alpha, 1)
336
+ self._ntk_cached = ntk_alpha
337
+ else:
338
+ ntk_alpha = self._ntk_cached
339
+ rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha).to(
340
+ hidden_states.device
341
+ )
342
+
343
+ if rotary_pos_emb is not None:
344
+ if isinstance(rotary_pos_emb, tuple):
345
+ rotary_pos_emb = rotary_pos_emb
346
+ else:
347
+ rotary_pos_emb = (rotary_pos_emb,) * 2
348
+
349
+ if rotary_pos_emb is not None:
350
+ q_pos_emb, k_pos_emb = rotary_pos_emb
351
+ # Slice the pos emb for current inference
352
+ cur_len = query.shape[1]
353
+ q_pos_emb = q_pos_emb[:, -cur_len:, :, :]
354
+ k_pos_emb = k_pos_emb[:, -cur_len:, :, :]
355
+ query = apply_rotary_pos_emb(query, q_pos_emb)
356
+ key = apply_rotary_pos_emb(key, k_pos_emb)
357
+
358
+ if layer_past is not None:
359
+ past_key, past_value = layer_past[0], layer_past[1]
360
+ key = torch.cat((past_key, key), dim=1)
361
+ value = torch.cat((past_value, value), dim=1)
362
+
363
+ if use_cache:
364
+ present = (key, value)
365
+ else:
366
+ present = None
367
+
368
+ if self.use_logn_attn and not self.training:
369
+ if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype:
370
+ self.logn_tensor = self.logn_tensor.to(query.device).type_as(query)
371
+ seq_start = key.size(1) - query.size(1)
372
+ seq_end = key.size(1)
373
+ logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
374
+ query = query * logn_tensor.expand_as(query)
375
+
376
+ if (
377
+ self.use_flash_attn
378
+ and flash_attn_unpadded_func is not None
379
+ and not self.is_fp32
380
+ and query.is_cuda
381
+ ):
382
+ q, k, v = query, key, value
383
+ context_layer = self.core_attention_flash(q, k, v)
384
+
385
+ context_layer = rearrange(
386
+ context_layer, "b s h d -> b s (h d)"
387
+ ).contiguous()
388
+ else:
389
+ query = query.permute(0, 2, 1, 3)
390
+ key = key.permute(0, 2, 1, 3)
391
+ value = value.permute(0, 2, 1, 3)
392
+ attn_output, attn_weight = self._attn(
393
+ query, key, value, attention_mask, head_mask
394
+ )
395
+ context_layer = self._merge_heads(
396
+ attn_output, self.num_heads, self.head_dim
397
+ )
398
+
399
+ attn_output = self.c_proj(context_layer)
400
+ outputs = (attn_output, present)
401
+ if output_attentions:
402
+ if (
403
+ self.use_flash_attn
404
+ and flash_attn_unpadded_func is not None
405
+ and not self.is_fp32
406
+ ):
407
+ raise ValueError("Cannot output attentions while using flash-attn")
408
+ else:
409
+ outputs += (attn_weight,)
410
+
411
+ return outputs
412
+
413
+
414
+ class QWenMLP(nn.Module):
415
+ def __init__(self, config):
416
+ super().__init__()
417
+ self.w1 = nn.Linear(
418
+ config.hidden_size, config.ffn_hidden_size // 2, bias=not config.no_bias
419
+ )
420
+ self.w2 = nn.Linear(
421
+ config.hidden_size, config.ffn_hidden_size // 2, bias=not config.no_bias
422
+ )
423
+ ff_dim_in = config.ffn_hidden_size // 2
424
+ self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias)
425
+
426
+ def forward(self, hidden_states):
427
+ a1 = self.w1(hidden_states)
428
+ a2 = self.w2(hidden_states)
429
+ intermediate_parallel = a1 * F.silu(a2)
430
+ output = self.c_proj(intermediate_parallel)
431
+ return output
432
+
433
+
434
+ class QWenBlock(nn.Module):
435
+ def __init__(self, config, layer_idx=None, num_expert=1):
436
+ super().__init__()
437
+ self.num_expert = num_expert
438
+ self.layer_number = layer_idx
439
+ self.apply_residual_connection_post_layernorm = (
440
+ config.apply_residual_connection_post_layernorm
441
+ )
442
+ hidden_size = config.hidden_size
443
+ self.apply_residual_connection_post_layernorm = (
444
+ config.apply_residual_connection_post_layernorm
445
+ )
446
+ self.bf16 = config.bf16
447
+
448
+ self.ln_1 = RMSNorm(
449
+ hidden_size,
450
+ eps=config.layer_norm_epsilon,
451
+ )
452
+ self.attn = QWenAttention(config, layer_number=layer_idx)
453
+ self.ln_2 = RMSNorm(
454
+ hidden_size,
455
+ eps=config.layer_norm_epsilon,
456
+ )
457
+
458
+ self.mlp = QWenMLP(config)
459
+
460
+ def forward(
461
+ self,
462
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
463
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
464
+ attention_mask: Optional[torch.FloatTensor] = None,
465
+ head_mask: Optional[torch.FloatTensor] = None,
466
+ encoder_hidden_states: Optional[torch.Tensor] = None,
467
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
468
+ use_cache: Optional[bool] = False,
469
+ output_attentions: Optional[bool] = False,
470
+ ):
471
+ layernorm_output = self.ln_1(hidden_states)
472
+
473
+ attn_outputs = self.attn(
474
+ layernorm_output,
475
+ layer_past=layer_past,
476
+ attention_mask=attention_mask,
477
+ head_mask=head_mask,
478
+ use_cache=use_cache,
479
+ output_attentions=output_attentions,
480
+ )
481
+ attn_output = attn_outputs[0]
482
+
483
+ outputs = attn_outputs[1:]
484
+
485
+ if self.apply_residual_connection_post_layernorm:
486
+ residual = layernorm_output
487
+ else:
488
+ residual = hidden_states
489
+ layernorm_input = attn_output + residual
490
+
491
+ layernorm_output = self.ln_2(layernorm_input)
492
+
493
+ if self.apply_residual_connection_post_layernorm:
494
+ residual = layernorm_output
495
+ else:
496
+ residual = layernorm_input
497
+
498
+ mlp_output = self.mlp(layernorm_output)
499
+ hidden_states = residual + mlp_output
500
+
501
+ if use_cache:
502
+ outputs = (hidden_states,) + outputs
503
+ else:
504
+ outputs = (hidden_states,) + outputs[1:]
505
+
506
+ return outputs
507
+
508
+
509
+ class QWenPreTrainedModel(PreTrainedModel):
510
+ config_class = QWenConfig
511
+ base_model_prefix = "transformer"
512
+ is_parallelizable = False
513
+ supports_gradient_checkpointing = True
514
+ _no_split_modules = ["QWenBlock"]
515
+
516
+ def __init__(self, *inputs, **kwargs):
517
+ super().__init__(*inputs, **kwargs)
518
+
519
+ def _init_weights(self, module):
520
+ """Initialize the weights."""
521
+ if isinstance(module, nn.Linear):
522
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
523
+ if module.bias is not None:
524
+ module.bias.data.zero_()
525
+ elif isinstance(module, nn.Embedding):
526
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
527
+ if module.padding_idx is not None:
528
+ module.weight.data[module.padding_idx].zero_()
529
+ elif isinstance(module, RMSNorm):
530
+ module.weight.data.fill_(1.0)
531
+
532
+ for name, p in module.named_parameters():
533
+ if name == "c_proj.weight":
534
+ p.data.normal_(
535
+ mean=0.0,
536
+ std=(
537
+ self.config.initializer_range
538
+ / math.sqrt(2 * self.config.n_layer)
539
+ ),
540
+ )
541
+
542
+ def _set_gradient_checkpointing(self, module, value=False):
543
+ if isinstance(module, QWenModel):
544
+ module.gradient_checkpointing = value
545
+
546
+
547
+ class QWenModel(QWenPreTrainedModel):
548
+ _keys_to_ignore_on_load_missing = ["attn.masked_bias"]
549
+
550
+ def __init__(self, config):
551
+ super().__init__(config)
552
+ self.vocab_size = config.padded_vocab_size
553
+ self.num_hidden_layers = config.num_hidden_layers
554
+ self.embed_dim = config.hidden_size
555
+
556
+ max_sequence_length = config.max_position_embeddings
557
+ self.position_embedding_type = config.pos_emb
558
+ self.gradient_checkpointing = False
559
+
560
+ if self.position_embedding_type == "learned":
561
+ self.wpe = nn.Embedding(max_sequence_length, self.embed_dim)
562
+ self.init_method(self.position_embeddings.weight)
563
+ self._position_embeddings_key = "position_embeddings"
564
+ self.init_method(self.position_embeddings.weight)
565
+ else:
566
+ self.wpe = None
567
+ self._position_embeddings_key = ""
568
+
569
+ self.wte = nn.Embedding(self.vocab_size, self.embed_dim)
570
+
571
+ self.drop = nn.Dropout(config.embd_pdrop)
572
+ self.h = nn.ModuleList(
573
+ [
574
+ QWenBlock(
575
+ config,
576
+ layer_idx=i,
577
+ )
578
+ for i in range(config.num_hidden_layers)
579
+ ]
580
+ )
581
+ self.ln_f = RMSNorm(
582
+ self.embed_dim,
583
+ eps=config.layer_norm_epsilon,
584
+ )
585
+
586
+ self.post_init()
587
+
588
+ def get_input_embeddings(self):
589
+ return self.wte
590
+
591
+ def set_input_embeddings(self, new_embeddings):
592
+ self.wte = new_embeddings
593
+
594
+ def forward(
595
+ self,
596
+ input_ids: Optional[torch.LongTensor] = None,
597
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
598
+ attention_mask: Optional[torch.FloatTensor] = None,
599
+ token_type_ids: Optional[torch.LongTensor] = None,
600
+ position_ids: Optional[torch.LongTensor] = None,
601
+ head_mask: Optional[torch.FloatTensor] = None,
602
+ inputs_embeds: Optional[torch.FloatTensor] = None,
603
+ encoder_hidden_states: Optional[torch.Tensor] = None,
604
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
605
+ use_cache: Optional[bool] = None,
606
+ output_attentions: Optional[bool] = None,
607
+ output_hidden_states: Optional[bool] = None,
608
+ return_dict: Optional[bool] = None,
609
+ ):
610
+ output_attentions = (
611
+ output_attentions
612
+ if output_attentions is not None
613
+ else self.config.output_attentions
614
+ )
615
+ output_hidden_states = (
616
+ output_hidden_states
617
+ if output_hidden_states is not None
618
+ else self.config.output_hidden_states
619
+ )
620
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
621
+ return_dict = (
622
+ return_dict if return_dict is not None else self.config.use_return_dict
623
+ )
624
+
625
+ if input_ids is not None and inputs_embeds is not None:
626
+ raise ValueError(
627
+ "You cannot specify both input_ids and inputs_embeds at the same time"
628
+ )
629
+ elif input_ids is not None:
630
+ input_shape = input_ids.size()
631
+ input_ids = input_ids.view(-1, input_shape[-1])
632
+ batch_size = input_ids.shape[0]
633
+ elif inputs_embeds is not None:
634
+ input_shape = inputs_embeds.size()[:-1]
635
+ batch_size = inputs_embeds.shape[0]
636
+ else:
637
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
638
+
639
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
640
+
641
+ if token_type_ids is not None:
642
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
643
+ if position_ids is not None:
644
+ position_ids = position_ids.view(-1, input_shape[-1])
645
+
646
+ if past_key_values is None:
647
+ past_length = 0
648
+ past_key_values = tuple([None] * len(self.h))
649
+ else:
650
+ past_length = past_key_values[0][0].size(-2)
651
+
652
+ if position_ids is None:
653
+ position_ids = torch.arange(
654
+ past_length,
655
+ input_shape[-1] + past_length,
656
+ dtype=torch.long,
657
+ device=device,
658
+ )
659
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
660
+
661
+ if attention_mask is not None:
662
+ if batch_size <= 0:
663
+ raise ValueError("batch_size has to be defined and > 0")
664
+ attention_mask = attention_mask.view(batch_size, -1)
665
+ attention_mask = attention_mask[:, None, None, :]
666
+ attention_mask = attention_mask.to(dtype=self.dtype)
667
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
668
+
669
+ encoder_attention_mask = None
670
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
671
+
672
+ if inputs_embeds is None:
673
+ inputs_embeds = self.wte(input_ids)
674
+ hidden_states = inputs_embeds
675
+ if self.wpe is not None:
676
+ position_embeds = self.wpe(position_ids)
677
+ hidden_states = hidden_states + position_embeds
678
+
679
+ hidden_states = self.drop(hidden_states)
680
+ output_shape = input_shape + (hidden_states.size(-1),)
681
+
682
+ if self.gradient_checkpointing and self.training:
683
+ if use_cache:
684
+ logger.warning_once(
685
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
686
+ )
687
+ use_cache = False
688
+
689
+ presents = () if use_cache else None
690
+ all_self_attentions = () if output_attentions else None
691
+ all_hidden_states = () if output_hidden_states else None
692
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
693
+
694
+ if output_hidden_states:
695
+ all_hidden_states = all_hidden_states + (hidden_states,)
696
+
697
+ if self.gradient_checkpointing and self.training:
698
+
699
+ def create_custom_forward(module):
700
+ def custom_forward(*inputs):
701
+ # None for past_key_value
702
+ return module(*inputs, use_cache, output_attentions)
703
+
704
+ return custom_forward
705
+
706
+ outputs = torch.utils.checkpoint.checkpoint(
707
+ create_custom_forward(block),
708
+ hidden_states,
709
+ None,
710
+ attention_mask,
711
+ head_mask[i],
712
+ encoder_hidden_states,
713
+ encoder_attention_mask,
714
+ )
715
+ else:
716
+ outputs = block(
717
+ hidden_states,
718
+ layer_past=layer_past,
719
+ attention_mask=attention_mask,
720
+ head_mask=head_mask[i],
721
+ encoder_hidden_states=encoder_hidden_states,
722
+ encoder_attention_mask=encoder_attention_mask,
723
+ use_cache=use_cache,
724
+ output_attentions=output_attentions,
725
+ )
726
+
727
+ hidden_states = outputs[0]
728
+ if use_cache is True:
729
+ presents = presents + (outputs[2 if output_attentions else 1],)
730
+
731
+ if output_attentions:
732
+ all_self_attentions = all_self_attentions + (outputs[1],)
733
+
734
+ hidden_states = self.ln_f(hidden_states)
735
+ hidden_states = hidden_states.view(output_shape)
736
+
737
+ if not return_dict:
738
+ return tuple(
739
+ v for v in [hidden_states, presents, all_hidden_states] if v is not None
740
+ )
741
+
742
+ return BaseModelOutputWithPast(
743
+ last_hidden_state=hidden_states,
744
+ past_key_values=presents,
745
+ hidden_states=all_hidden_states,
746
+ attentions=all_self_attentions,
747
+ )
748
+
749
+
750
+
751
+ class QWenLMHeadModel(QWenPreTrainedModel):
752
+ _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"]
753
+ _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"]
754
+
755
+ def __init__(self, config):
756
+ super().__init__(config)
757
+ assert (
758
+ config.bf16 + config.fp16 + config.fp32 <= 1
759
+ ), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true"
760
+
761
+ autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0
762
+
763
+ if autoset_precision:
764
+ if SUPPORT_BF16:
765
+ logger.warn(
766
+ "The model is automatically converting to bf16 for faster inference. "
767
+ "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
768
+ )
769
+ config.bf16 = True
770
+ elif SUPPORT_FP16:
771
+ logger.warn(
772
+ "The model is automatically converting to fp16 for faster inference. "
773
+ "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
774
+ )
775
+ config.fp16 = True
776
+ else:
777
+ config.fp32 = True
778
+
779
+ if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:
780
+ logger.warn("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".")
781
+ if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:
782
+ logger.warn("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster")
783
+ if config.fp32:
784
+ if SUPPORT_BF16:
785
+ logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")
786
+ elif SUPPORT_FP16:
787
+ logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")
788
+
789
+ if config.use_flash_attn == "auto":
790
+ if config.bf16 or config.fp16:
791
+ logger.warn("Try importing flash-attention for faster inference...")
792
+ config.use_flash_attn = True
793
+ else:
794
+ config.use_flash_attn = False
795
+ if config.use_flash_attn and config.fp32:
796
+ logger.warn("Flash attention will be disabled because it does NOT support fp32.")
797
+
798
+ if config.use_flash_attn:
799
+ global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func
800
+ try:
801
+ from flash_attn.layers.rotary import apply_rotary_emb_func as __apply_rotary_emb_func
802
+ apply_rotary_emb_func = __apply_rotary_emb_func
803
+ except ImportError:
804
+ logger.warn(
805
+ "Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency "
806
+ "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary"
807
+ )
808
+
809
+ try:
810
+ from flash_attn.ops.rms_norm import rms_norm as __rms_norm
811
+ rms_norm = __rms_norm
812
+ except ImportError:
813
+ logger.warn(
814
+ "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency "
815
+ "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm"
816
+ )
817
+
818
+ try:
819
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
820
+ flash_attn_unpadded_func = __flash_attn_unpadded_func
821
+ except ImportError:
822
+ logger.warn(
823
+ "Warning: import flash_attn fail, please install FlashAttention to get higher efficiency "
824
+ "https://github.com/Dao-AILab/flash-attention"
825
+ )
826
+
827
+ self.transformer = QWenModel(config)
828
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
829
+
830
+ if config.bf16:
831
+ self.transformer.bfloat16()
832
+ self.lm_head.bfloat16()
833
+ if config.fp16:
834
+ self.transformer.half()
835
+ self.lm_head.half()
836
+ self.post_init()
837
+
838
+ def get_output_embeddings(self):
839
+ return self.lm_head
840
+
841
+ def set_output_embeddings(self, new_embeddings):
842
+ self.lm_head = new_embeddings
843
+
844
+ def prepare_inputs_for_generation(
845
+ self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
846
+ ):
847
+ token_type_ids = kwargs.get("token_type_ids", None)
848
+ if past_key_values:
849
+ input_ids = input_ids[:, -1].unsqueeze(-1)
850
+ if token_type_ids is not None:
851
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
852
+
853
+ attention_mask = kwargs.get("attention_mask", None)
854
+ position_ids = kwargs.get("position_ids", None)
855
+
856
+ if attention_mask is not None and position_ids is None:
857
+ position_ids = attention_mask.long().cumsum(-1) - 1
858
+ position_ids.masked_fill_(attention_mask == 0, 1)
859
+ if past_key_values:
860
+ position_ids = position_ids[:, -1].unsqueeze(-1)
861
+ else:
862
+ position_ids = None
863
+
864
+ if inputs_embeds is not None and past_key_values is None:
865
+ model_inputs = {"inputs_embeds": inputs_embeds}
866
+ else:
867
+ model_inputs = {"input_ids": input_ids}
868
+
869
+ model_inputs.update(
870
+ {
871
+ "past_key_values": past_key_values,
872
+ "use_cache": kwargs.get("use_cache"),
873
+ "position_ids": position_ids,
874
+ "attention_mask": attention_mask,
875
+ "token_type_ids": token_type_ids,
876
+ }
877
+ )
878
+ return model_inputs
879
+
880
+ def forward(
881
+ self,
882
+ input_ids: Optional[torch.LongTensor] = None,
883
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
884
+ attention_mask: Optional[torch.FloatTensor] = None,
885
+ token_type_ids: Optional[torch.LongTensor] = None,
886
+ position_ids: Optional[torch.LongTensor] = None,
887
+ head_mask: Optional[torch.FloatTensor] = None,
888
+ inputs_embeds: Optional[torch.FloatTensor] = None,
889
+ encoder_hidden_states: Optional[torch.Tensor] = None,
890
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
891
+ labels: Optional[torch.LongTensor] = None,
892
+ use_cache: Optional[bool] = None,
893
+ output_attentions: Optional[bool] = None,
894
+ output_hidden_states: Optional[bool] = None,
895
+ return_dict: Optional[bool] = None,
896
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
897
+
898
+ return_dict = (
899
+ return_dict if return_dict is not None else self.config.use_return_dict
900
+ )
901
+
902
+ transformer_outputs = self.transformer(
903
+ input_ids,
904
+ past_key_values=past_key_values,
905
+ attention_mask=attention_mask,
906
+ token_type_ids=token_type_ids,
907
+ position_ids=position_ids,
908
+ head_mask=head_mask,
909
+ inputs_embeds=inputs_embeds,
910
+ encoder_hidden_states=encoder_hidden_states,
911
+ encoder_attention_mask=encoder_attention_mask,
912
+ use_cache=use_cache,
913
+ output_attentions=output_attentions,
914
+ output_hidden_states=output_hidden_states,
915
+ return_dict=return_dict,
916
+ )
917
+ hidden_states = transformer_outputs[0]
918
+
919
+ lm_logits = self.lm_head(hidden_states)
920
+
921
+ loss = None
922
+ if labels is not None:
923
+ labels = labels.to(lm_logits.device)
924
+ shift_logits = lm_logits[..., :-1, :].contiguous()
925
+ shift_labels = labels[..., 1:].contiguous()
926
+ loss_fct = CrossEntropyLoss()
927
+ loss = loss_fct(
928
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
929
+ )
930
+
931
+ if not return_dict:
932
+ output = (lm_logits,) + transformer_outputs[1:]
933
+ return ((loss,) + output) if loss is not None else output
934
+ returns = CausalLMOutputWithPast(
935
+ loss=loss,
936
+ logits=lm_logits,
937
+ past_key_values=transformer_outputs.past_key_values,
938
+ hidden_states=transformer_outputs.hidden_states,
939
+ attentions=transformer_outputs.attentions,
940
+ )
941
+ return returns
942
+
943
+ @staticmethod
944
+ def _reorder_cache(
945
+ past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
946
+ ) -> Tuple[Tuple[torch.Tensor]]:
947
+
948
+ return tuple(
949
+ tuple(
950
+ past_state.index_select(0, beam_idx.to(past_state.device))
951
+ for past_state in layer_past
952
+ )
953
+ for layer_past in past_key_values
954
+ )
955
+
956
+ def chat(
957
+ self,
958
+ tokenizer: PreTrainedTokenizer,
959
+ query: str,
960
+ history: Optional[HistoryType],
961
+ system: str = "You are a helpful assistant.",
962
+ append_history: bool = True,
963
+ stream: Optional[bool] = False
964
+ ) -> Tuple[str, HistoryType]:
965
+
966
+
967
+ if history is None:
968
+ history = []
969
+
970
+ raw_text, context_tokens = make_context(
971
+ tokenizer,
972
+ query,
973
+ history=history,
974
+ system=system,
975
+ max_window_size=6144,
976
+ chat_format=self.generation_config.chat_format,
977
+ )
978
+
979
+ stop_words_ids = get_stop_words_ids(
980
+ self.generation_config.chat_format, tokenizer
981
+ )
982
+ input_ids = torch.tensor([context_tokens]).to(self.device)
983
+ if stream:
984
+ assert self.generation_config.chat_format == 'chatml'
985
+ from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
986
+ self.__class__.generate = NewGenerationMixin.generate
987
+ self.__class__.sample_stream = NewGenerationMixin.sample_stream
988
+ stream_config = StreamGenerationConfig(**self.generation_config.to_dict(), do_stream=True)
989
+ def stream_generator():
990
+ outputs = []
991
+ for token in self.generate(input_ids, return_dict_in_generate=False, generation_config=stream_config):
992
+ outputs.append(token.item())
993
+ if outputs[-1] in (tokenizer.im_end_id, tokenizer.im_start_id):
994
+ break
995
+ yield tokenizer.decode(outputs, skip_special_tokens=True)
996
+
997
+ return stream_generator()
998
+ else:
999
+ outputs = self.generate(
1000
+ input_ids,
1001
+ stop_words_ids = stop_words_ids,
1002
+ return_dict_in_generate=False
1003
+ )
1004
+
1005
+ response = decode_tokens(
1006
+ outputs[0],
1007
+ tokenizer,
1008
+ raw_text_len=len(raw_text),
1009
+ context_length=len(context_tokens),
1010
+ chat_format=self.generation_config.chat_format,
1011
+ verbose=False,
1012
+ )
1013
+
1014
+ if append_history:
1015
+ history.append((query, response))
1016
+
1017
+ return response, history
1018
+
1019
+ def generate(
1020
+ self,
1021
+ inputs: Optional[torch.Tensor] = None,
1022
+ generation_config: Optional[GenerationConfig] = None,
1023
+ logits_processor: Optional[LogitsProcessorList] = None,
1024
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
1025
+ prefix_allowed_tokens_fn: Optional[
1026
+ Callable[[int, torch.Tensor], List[int]]
1027
+ ] = None,
1028
+ synced_gpus: Optional[bool] = None,
1029
+ streamer: Optional["BaseStreamer"] = None,
1030
+ **kwargs,
1031
+ ) -> Union[GenerateOutput, torch.LongTensor]:
1032
+ # Process stop_words_ids.
1033
+ stop_words_ids = kwargs.pop("stop_words_ids", None)
1034
+ if stop_words_ids is None and generation_config is not None:
1035
+ stop_words_ids = getattr(generation_config, "stop_words_ids", None)
1036
+ if stop_words_ids is None:
1037
+ stop_words_ids = getattr(self.generation_config, "stop_words_ids", None)
1038
+
1039
+ if stop_words_ids is not None:
1040
+ stop_words_logits_processor = StopWordsLogitsProcessor(
1041
+ stop_words_ids=stop_words_ids,
1042
+ eos_token_id=self.generation_config.eos_token_id,
1043
+ )
1044
+ if logits_processor is None:
1045
+ logits_processor = LogitsProcessorList([stop_words_logits_processor])
1046
+ else:
1047
+ logits_processor.append(stop_words_logits_processor)
1048
+
1049
+ return super().generate(
1050
+ inputs,
1051
+ generation_config,
1052
+ logits_processor,
1053
+ stopping_criteria,
1054
+ prefix_allowed_tokens_fn,
1055
+ synced_gpus,
1056
+ streamer,
1057
+ **kwargs,
1058
+ )
1059
+
1060
+
1061
+ class RotaryEmbedding(torch.nn.Module):
1062
+ def __init__(self, dim, base=10000):
1063
+ super().__init__()
1064
+ self.dim = dim
1065
+ self.base = base
1066
+ self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
1067
+ if importlib.util.find_spec("einops") is None:
1068
+ raise RuntimeError("einops is required for Rotary Embedding")
1069
+
1070
+ self._rotary_pos_emb_cache = None
1071
+ self._seq_len_cached = 0
1072
+ self._ntk_alpha_cached = 1.0
1073
+
1074
+ def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0):
1075
+ seqlen = max_seq_len + offset
1076
+ if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
1077
+ base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
1078
+ self.inv_freq = 1.0 / (
1079
+ base
1080
+ ** (
1081
+ torch.arange(0, self.dim, 2, device=self.inv_freq.device).float()
1082
+ / self.dim
1083
+ )
1084
+ )
1085
+ self._seq_len_cached = seqlen
1086
+ self._ntk_alpha_cached = ntk_alpha
1087
+ seq = torch.arange(seqlen, device=self.inv_freq.device)
1088
+ freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
1089
+ emb = torch.cat((freqs, freqs), dim=-1)
1090
+ from einops import rearrange
1091
+
1092
+ self._rotary_pos_emb_cache = rearrange(emb, "n d -> 1 n 1 d")
1093
+
1094
+ def forward(self, max_seq_len, offset=0, ntk_alpha=1.0):
1095
+ self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha)
1096
+ return self._rotary_pos_emb_cache[:, offset : offset + max_seq_len]
1097
+
1098
+
1099
+ def _rotate_half(x):
1100
+ from einops import rearrange
1101
+
1102
+ x = rearrange(x, "... (j d) -> ... j d", j=2)
1103
+ x1, x2 = x.unbind(dim=-2)
1104
+ return torch.cat((-x2, x1), dim=-1)
1105
+
1106
+
1107
+ def apply_rotary_pos_emb(t, freqs):
1108
+ if apply_rotary_emb_func is not None:
1109
+ t_ = t.float()
1110
+ freqs = freqs.squeeze(0).squeeze(1)
1111
+ cos = freqs[:, : freqs.shape[-1] // 2].cos()
1112
+ sin = freqs[:, : freqs.shape[-1] // 2].sin()
1113
+ output = apply_rotary_emb_func(t_, cos, sin).type_as(t)
1114
+ return output
1115
+ else:
1116
+ rot_dim = freqs.shape[-1]
1117
+ t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
1118
+ t_ = t_.float()
1119
+ t_pass_ = t_pass_.float()
1120
+ t_ = (t_ * freqs.cos()) + (_rotate_half(t_) * freqs.sin())
1121
+ return torch.cat((t_, t_pass_), dim=-1).type_as(t)
1122
+
1123
+
1124
+ class RMSNorm(torch.nn.Module):
1125
+ def __init__(self, dim: int, eps: float = 1e-6):
1126
+ super().__init__()
1127
+ self.eps = eps
1128
+ self.weight = nn.Parameter(torch.ones(dim))
1129
+
1130
+ def _norm(self, x):
1131
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
1132
+
1133
+ def forward(self, x):
1134
+ if rms_norm is not None and x.is_cuda:
1135
+ return rms_norm(x, self.weight, self.eps)
1136
+ else:
1137
+ output = self._norm(x.float()).type_as(x)
1138
+ return output * self.weight
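A minimal usage sketch, not one of the uploaded files: it assumes the repository is loaded with trust_remote_code=True so that modeling_qwen.py and tokenization_qwen.py above are picked up, that an accompanying generation_config.json defines chat_format, and that "path/to/this/repo" is a placeholder for the actual model path.

# Hypothetical usage sketch; "path/to/this/repo" is a placeholder.
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

repo = "path/to/this/repo"
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo, device_map="auto", trust_remote_code=True).eval()
# The accompanying generation_config.json (assumed to ship with the repo) supplies chat_format and sampling defaults.
model.generation_config = GenerationConfig.from_pretrained(repo, trust_remote_code=True)

response, history = model.chat(tokenizer, "Hello!", history=None)
print(response)

The chat() helper above builds the ChatML prompt, generates with the stop-word logits processor, and returns the decoded response together with the updated history.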
qwen.tiktoken ADDED
The diff for this file is too large to render. See raw diff
 
qwen_generation_utils.py ADDED
@@ -0,0 +1,411 @@
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """Generation support."""
7
+
8
+ from typing import Tuple, List, Union, Iterable
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from transformers import PreTrainedTokenizer
14
+ from transformers import logging
15
+ from transformers.generation import LogitsProcessor
16
+
17
+ logger = logging.get_logger(__name__)
18
+
19
+ # Types.
20
+ HistoryType = List[Tuple[str, str]]
21
+ TokensType = List[int]
22
+ BatchTokensType = List[List[int]]
23
+
24
+
25
+ def pad_batch(batch: BatchTokensType, pad_id: int, seq_length: int) -> BatchTokensType:
26
+ for tokens in batch:
27
+ context_length = len(tokens)
28
+ if context_length < seq_length:
29
+ tokens.extend([pad_id] * (seq_length - context_length))
30
+ return batch
31
+
32
+
33
+ def get_ltor_masks_and_position_ids(
34
+ data,
35
+ eod_token,
36
+ reset_position_ids,
37
+ reset_attention_mask,
38
+ eod_mask_loss,
39
+ ):
40
+ """Build masks and position id for left to right model."""
41
+
42
+ # Extract batch size and sequence length.
43
+ micro_batch_size, seq_length = data.size()
44
+
45
+ # Attention mask (lower triangular).
46
+ if reset_attention_mask:
47
+ att_mask_batch = micro_batch_size
48
+ else:
49
+ att_mask_batch = 1
50
+ attention_mask = torch.tril(
51
+ torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
52
+ ).view(att_mask_batch, 1, seq_length, seq_length)
53
+
54
+ # Loss mask.
55
+ loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
56
+ if eod_mask_loss:
57
+ loss_mask[data == eod_token] = 0.0
58
+
59
+ # Position ids.
60
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
61
+ position_ids = position_ids.unsqueeze(0).expand_as(data)
62
+ # We need to clone as the ids will be modified based on batch index.
63
+ if reset_position_ids:
64
+ position_ids = position_ids.clone()
65
+
66
+ if reset_position_ids or reset_attention_mask:
67
+ # Loop through the batches:
68
+ for b in range(micro_batch_size):
69
+
70
+ # Find indices where the EOD token is.
71
+ eod_index = position_ids[b, data[b] == eod_token]
72
+ # Detach indices from positions if we are going to modify positions.
73
+ if reset_position_ids:
74
+ eod_index = eod_index.clone()
75
+
76
+ # Loop through EOD indices:
77
+ prev_index = 0
78
+ for j in range(eod_index.size()[0]):
79
+ i = eod_index[j]
80
+ # Mask attention across document boundaries.
81
+ if reset_attention_mask:
82
+ attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
83
+ # Reset positions.
84
+ if reset_position_ids:
85
+ position_ids[b, (i + 1) :] -= i + 1 - prev_index
86
+ prev_index = i + 1
87
+
88
+ # Convert attention mask to binary:
89
+ attention_mask = attention_mask < 0.5
90
+
91
+ return attention_mask, loss_mask, position_ids
92
+
93
+
94
+ def get_batch(context_tokens: torch.LongTensor, eod_id: int):
95
+ """Generate batch from context tokens."""
96
+ # Make tokens contiguous (they remain on their current device).
97
+ tokens = context_tokens.contiguous().to(context_tokens.device)
98
+ # Get the attention mask and position ids.
99
+ attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
100
+ tokens,
101
+ eod_id,
102
+ reset_position_ids=False,
103
+ reset_attention_mask=False,
104
+ eod_mask_loss=False,
105
+ )
106
+ return tokens, attention_mask, position_ids
107
+
108
+
109
+ def get_stop_words_ids(chat_format, tokenizer):
110
+ if chat_format == "raw":
111
+ stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]]
112
+ elif chat_format == "chatml":
113
+ stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
114
+ else:
115
+ raise NotImplementedError(f"Unknown chat format {chat_format!r}")
116
+ return stop_words_ids
117
+
118
+
119
+ def make_context(
120
+ tokenizer: PreTrainedTokenizer,
121
+ query: str,
122
+ history: List[Tuple[str, str]] = None,
123
+ system: str = "",
124
+ max_window_size: int = 6144,
125
+ chat_format: str = "chatml",
126
+ ):
127
+ if history is None:
128
+ history = []
129
+
130
+ if chat_format == "chatml":
131
+ im_start, im_end = "<|im_start|>", "<|im_end|>"
132
+ im_start_tokens = [tokenizer.im_start_id]
133
+ im_end_tokens = [tokenizer.im_end_id]
134
+ nl_tokens = tokenizer.encode("\n")
135
+
136
+ def _tokenize_str(role, content):
137
+ return f"{role}\n{content}", tokenizer.encode(
138
+ role, allowed_special=set()
139
+ ) + nl_tokens + tokenizer.encode(content, allowed_special=set())
140
+
141
+ system_text, system_tokens_part = _tokenize_str("system", system)
142
+ system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
143
+
144
+ raw_text = ""
145
+ context_tokens = []
146
+
147
+ for turn_query, turn_response in reversed(history):
148
+ query_text, query_tokens_part = _tokenize_str("user", turn_query)
149
+ query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
150
+ response_text, response_tokens_part = _tokenize_str(
151
+ "assistant", turn_response
152
+ )
153
+ response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
154
+
155
+ next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
156
+ prev_chat = (
157
+ f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
158
+ )
159
+
160
+ current_context_size = (
161
+ len(system_tokens) + len(next_context_tokens) + len(context_tokens)
162
+ )
163
+ if current_context_size < max_window_size:
164
+ context_tokens = next_context_tokens + context_tokens
165
+ raw_text = prev_chat + raw_text
166
+ else:
167
+ break
168
+
169
+ context_tokens = system_tokens + context_tokens
170
+ raw_text = f"{im_start}{system_text}{im_end}" + raw_text
171
+ context_tokens += (
172
+ nl_tokens
173
+ + im_start_tokens
174
+ + _tokenize_str("user", query)[1]
175
+ + im_end_tokens
176
+ + nl_tokens
177
+ + im_start_tokens
178
+ + tokenizer.encode("assistant")
179
+ + nl_tokens
180
+ )
181
+ raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
182
+
183
+ elif chat_format == "raw":
184
+ raw_text = query
185
+ context_tokens = tokenizer.encode(raw_text)
186
+ else:
187
+ raise NotImplementedError(f"Unknown chat format {chat_format!r}")
188
+
189
+ return raw_text, context_tokens
190
+
191
+
192
+ def _decode_default(
193
+ tokens: List[int],
194
+ *,
195
+ stop_words: List[str],
196
+ eod_words: List[str],
197
+ tokenizer: PreTrainedTokenizer,
198
+ raw_text_len: int,
199
+ verbose: bool = False,
200
+ return_end_reason: bool = False,
201
+ ):
202
+ trim_decode_tokens = tokenizer.decode(tokens)[raw_text_len:]
203
+ if verbose:
204
+ print("\nRaw Generate: ", trim_decode_tokens)
205
+
206
+ end_reason = f"Gen length {len(tokens)}"
207
+ for stop_word in stop_words:
208
+ trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
209
+ for eod_word in eod_words:
210
+ if eod_word in trim_decode_tokens:
211
+ end_reason = f"Gen {eod_word!r}"
212
+ trim_decode_tokens = trim_decode_tokens.split(eod_word)[0]
213
+ trim_decode_tokens = trim_decode_tokens.strip()
214
+ if verbose:
215
+ print("\nEnd Reason:", end_reason)
216
+ print("\nGenerate: ", trim_decode_tokens)
217
+
218
+ if return_end_reason:
219
+ return trim_decode_tokens, end_reason
220
+ else:
221
+ return trim_decode_tokens
222
+
223
+
224
+ def _decode_chatml(
225
+ tokens: List[int],
226
+ *,
227
+ stop_words: List[str],
228
+ eod_token_ids: List[int],
229
+ tokenizer: PreTrainedTokenizer,
230
+ raw_text_len: int,
231
+ context_length: int,
232
+ verbose: bool = False,
233
+ return_end_reason: bool = False,
234
+ ):
235
+ end_reason = f"Gen length {len(tokens)}"
236
+ eod_token_idx = context_length
237
+ for eod_token_idx in range(context_length, len(tokens)):
238
+ if tokens[eod_token_idx] in eod_token_ids:
239
+ end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
240
+ break
241
+
242
+ trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx])[raw_text_len:]
243
+ if verbose:
244
+ print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens)[raw_text_len:])
245
+ print("\nRaw Generate:", trim_decode_tokens)
246
+ print("\nEnd Reason:", end_reason)
247
+ for stop_word in stop_words:
248
+ trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
249
+ trim_decode_tokens = trim_decode_tokens.strip()
250
+ if verbose:
251
+ print("\nGenerate:", trim_decode_tokens)
252
+
253
+ if return_end_reason:
254
+ return trim_decode_tokens, end_reason
255
+ else:
256
+ return trim_decode_tokens
257
+
258
+
259
+ def decode_tokens(
260
+ tokens: Union[torch.LongTensor, TokensType],
261
+ tokenizer: PreTrainedTokenizer,
262
+ raw_text_len: int,
263
+ context_length: int,
264
+ chat_format: str,
265
+ verbose: bool = False,
266
+ return_end_reason: bool = False,
267
+ ) -> str:
268
+ if torch.is_tensor(tokens):
269
+ tokens = tokens.cpu().numpy().tolist()
270
+
271
+ if chat_format == "chatml":
272
+ return _decode_chatml(
273
+ tokens,
274
+ stop_words=[],
275
+ eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id],
276
+ tokenizer=tokenizer,
277
+ raw_text_len=raw_text_len,
278
+ context_length=context_length,
279
+ verbose=verbose,
280
+ return_end_reason=return_end_reason,
281
+ )
282
+ elif chat_format == "raw":
283
+ return _decode_default(
284
+ tokens,
285
+ stop_words=["<|endoftext|>"],
286
+ eod_words=["<|endoftext|>"],
287
+ tokenizer=tokenizer,
288
+ raw_text_len=raw_text_len,
289
+ verbose=verbose,
290
+ return_end_reason=return_end_reason,
291
+ )
292
+ else:
293
+ raise NotImplementedError(f"Unknown chat format {chat_format!r}")
294
+
295
+
296
+ class StopWordsLogitsProcessor(LogitsProcessor):
297
+ """
298
+ :class:`transformers.LogitsProcessor` that forces generation to stop once any of the specified token sequences appears.
299
+
300
+ Args:
301
+ stop_words_ids (:obj:`List[List[int]]`):
302
+ List of lists of token ids of the stop sequences. In order to get the token ids of the words
303
+ that should stop the generation, use :obj:`tokenizer(stop_word,
304
+ add_prefix_space=True).input_ids`.
305
+ eos_token_id (:obj:`int`):
306
+ The id of the `end-of-sequence` token.
307
+ """
308
+
309
+ def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int):
310
+
311
+ if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0:
312
+ raise ValueError(
313
+ f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}."
314
+ )
315
+ if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids):
316
+ raise ValueError(
317
+ f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}."
318
+ )
319
+ if any(
320
+ any(
321
+ (not isinstance(token_id, (int, np.integer)) or token_id < 0)
322
+ for token_id in stop_word_ids
323
+ )
324
+ for stop_word_ids in stop_words_ids
325
+ ):
326
+ raise ValueError(
327
+ f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}."
328
+ )
329
+
330
+ self.stop_words_ids = list(
331
+ filter(
332
+ lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids
333
+ )
334
+ )
335
+ self.eos_token_id = eos_token_id
336
+ for stop_token_seq in self.stop_words_ids:
337
+ assert (
338
+ len(stop_token_seq) > 0
339
+ ), "Stop words token sequences {} cannot have an empty list".format(
340
+ stop_words_ids
341
+ )
342
+
343
+ def __call__(
344
+ self, input_ids: torch.LongTensor, scores: torch.FloatTensor
345
+ ) -> torch.FloatTensor:
346
+ stopped_samples = self._calc_stopped_samples(input_ids)
347
+ for i, should_stop in enumerate(stopped_samples):
348
+ if should_stop:
349
+ scores[i, self.eos_token_id] = float(2**15)
350
+ return scores
351
+
352
+ def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
353
+ if len(tokens) == 0:
354
+ # an empty stop sequence matches any prefix, so always stop
355
+ return True
356
+ elif len(tokens) > len(prev_tokens):
357
+ # if bad word tokens are longer then prev input_ids they can't be equal
358
+ return False
359
+ elif prev_tokens[-len(tokens) :].tolist() == tokens:
360
+ # if tokens match
361
+ return True
362
+ else:
363
+ return False
364
+
365
+ def _calc_stopped_samples(self, prev_input_ids: Iterable[int]) -> Iterable[int]:
366
+ stopped_samples = []
367
+ for prev_input_ids_slice in prev_input_ids:
368
+ match = False
369
+ for stop_token_seq in self.stop_words_ids:
370
+ if self._tokens_match(prev_input_ids_slice, stop_token_seq):
371
+ # a stop sequence matched; mark this sample as stopped
372
+ match = True
373
+ break
374
+ stopped_samples.append(match)
375
+
376
+ return stopped_samples
377
+
378
+
379
+ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
380
+ """This function has been mostly taken from huggingface conversational
381
+ ai code at
382
+ https://medium.com/huggingface/how-to-build-a-state-of-the-art-
383
+ conversational-ai-with-transfer-learning-2d818ac26313"""
384
+
385
+ if top_k > 0:
386
+ # Remove all tokens with a probability less than the
387
+ # last token of the top-k
388
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
389
+ logits[indices_to_remove] = filter_value
390
+
391
+ if top_p > 0.0:
392
+ # Sort logits in descending order to compute cumulative probabilities
393
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
394
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
395
+
396
+ # Remove tokens with cumulative probability above the threshold
397
+ sorted_indices_to_remove = cumulative_probs > top_p
398
+ # Shift the indices to the right to keep also the first token
399
+ # above the threshold
400
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
401
+ sorted_indices_to_remove[..., 0] = 0
402
+ for i in range(sorted_indices.size(0)):
403
+ indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
404
+ logits[i][indices_to_remove] = filter_value
405
+
406
+ return logits
407
+
408
+
409
+ def switch(val1, val2, boolean):
410
+ boolean = boolean.type_as(val1)
411
+ return (1 - boolean) * val1 + boolean * val2
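For reference, a sketch of what chat() above does when these helpers are used standalone (assuming a model and tokenizer already loaded as in the earlier sketch, and the "chatml" chat format):

import torch

query = "Hello!"
raw_text, context_tokens = make_context(
    tokenizer, query, history=[], system="You are a helpful assistant.",
    max_window_size=6144, chat_format="chatml",
)
stop_words_ids = get_stop_words_ids("chatml", tokenizer)
input_ids = torch.tensor([context_tokens]).to(model.device)
outputs = model.generate(input_ids, stop_words_ids=stop_words_ids, return_dict_in_generate=False)
response = decode_tokens(
    outputs[0], tokenizer,
    raw_text_len=len(raw_text), context_length=len(context_tokens),
    chat_format="chatml", verbose=False,
)
print(response)

decode_tokens() strips the prompt and everything after the first <|im_end|>/<|im_start|> token, mirroring the stop words installed by get_stop_words_ids().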
special_tokens_map.json ADDED
@@ -0,0 +1 @@
1
+ {}
tokenization_qwen.py ADDED
@@ -0,0 +1,226 @@
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """Tokenization classes for QWen."""
7
+
8
+ import base64
9
+ import logging
10
+ import os
11
+ import unicodedata
12
+ from typing import Collection, Dict, List, Set, Tuple, Union
13
+
14
+ import tiktoken
15
+ from transformers import PreTrainedTokenizer, AddedToken
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken"}
21
+
22
+ PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
23
+ ENDOFTEXT = "<|endoftext|>"
24
+ IMSTART = "<|im_start|>"
25
+ IMEND = "<|im_end|>"
26
+ # since the default behavior has been changed to allow special tokens in
27
+ # regular texts, the surface forms of special tokens need to be
28
+ # as different as possible to minimize the impact
29
+ EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
30
+ SPECIAL_TOKENS = (
31
+ ENDOFTEXT,
32
+ IMSTART,
33
+ IMEND,
34
+ ) + EXTRAS
35
+
36
+
37
+ def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
38
+ contents = open(tiktoken_bpe_file, "rb").read()
39
+ return {
40
+ base64.b64decode(token): int(rank)
41
+ for token, rank in (line.split() for line in contents.splitlines() if line)
42
+ }
43
+
44
+ class QWenTokenizer(PreTrainedTokenizer):
45
+ """QWen tokenizer."""
46
+
47
+ vocab_files_names = VOCAB_FILES_NAMES
48
+
49
+ def __init__(
50
+ self,
51
+ vocab_file,
52
+ errors="replace",
53
+ **kwargs,
54
+ ):
55
+ super().__init__(**kwargs)
56
+
57
+ self.errors = errors # how to handle errors in decoding
58
+
59
+ self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: dict[bytes, int]
60
+ self.special_tokens = {
61
+ token: index
62
+ for index, token in enumerate(
63
+ SPECIAL_TOKENS, start=len(self.mergeable_ranks)
64
+ )
65
+ }
66
+
67
+ enc = tiktoken.Encoding(
68
+ "Qwen",
69
+ pat_str=PAT_STR,
70
+ mergeable_ranks=self.mergeable_ranks,
71
+ special_tokens=self.special_tokens,
72
+ )
73
+ assert (
74
+ len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
75
+ ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"
76
+
77
+ self.decoder = {
78
+ v: k for k, v in self.mergeable_ranks.items()
79
+ } # type: dict[int, bytes|str]
80
+ self.decoder.update({v: k for k, v in self.special_tokens.items()})
81
+
82
+ self.tokenizer = enc # type: tiktoken.Encoding
83
+
84
+ self.eod_id = self.tokenizer.eot_token
85
+ self.im_start_id = self.special_tokens[IMSTART]
86
+ self.im_end_id = self.special_tokens[IMEND]
87
+
88
+ def __len__(self) -> int:
89
+ return self.tokenizer.n_vocab
90
+
91
+ def get_vocab(self) -> Dict[bytes, int]:
92
+ return self.mergeable_ranks
93
+
94
+ def convert_tokens_to_ids(
95
+ self, tokens: Union[bytes, str, List[Union[bytes, str]]]
96
+ ) -> List[int]:
97
+ ids = []
98
+ if isinstance(tokens, (str, bytes)):
99
+ if tokens in self.special_tokens:
100
+ return self.special_tokens[tokens]
101
+ else:
102
+ return self.mergeable_ranks.get(tokens)
103
+ for token in tokens:
104
+ if token in self.special_tokens:
105
+ ids.append(self.special_tokens[token])
106
+ else:
107
+ ids.append(self.mergeable_ranks.get(token))
108
+ return ids
109
+
110
+ def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
111
+ if not special_tokens and new_tokens:
112
+ raise ValueError('Adding regular tokens is not supported')
113
+ for token in new_tokens:
114
+ surface_form = token.content if isinstance(token, AddedToken) else token
115
+ if surface_form not in SPECIAL_TOKENS:
116
+ raise ValueError('Adding unknown special tokens is not supported')
117
+ return 0
118
+
119
+ def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
120
+ """
121
+ Save only the vocabulary of the tokenizer (the BPE ranks file).
122
+
123
+ Returns:
124
+ `Tuple(str)`: Paths to the files saved.
125
+ """
126
+ file_path = os.path.join(save_directory, "qwen.tiktoken")
127
+ with open(file_path, "w", encoding="utf8") as w:
128
+ for k, v in self.mergeable_ranks.items():
129
+ line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n"
130
+ w.write(line)
131
+ return (file_path,)
132
+
133
+ def tokenize(
134
+ self,
135
+ text: str,
136
+ allowed_special: Union[Set, str] = "all",
137
+ disallowed_special: Union[Collection, str] = (),
138
+ **kwargs,
139
+ ) -> List[Union[bytes, str]]:
140
+ """
141
+ Converts a string into a sequence of tokens.
142
+
143
+ Args:
144
+ text (`str`):
145
+ The sequence to be encoded.
146
+ allowed_special (`Literal["all"]` or `set`):
147
+ The surface forms of the tokens to be encoded as special tokens in regular texts.
148
+ Default to "all".
149
+ disallowed_special (`Literal["all"]` or `Collection`):
150
+ The surface forms of the tokens that should not be in regular texts and trigger errors.
151
+ Defaults to an empty tuple.
152
+
153
+ kwargs (additional keyword arguments, *optional*):
154
+ Will be passed to the underlying model specific encode method.
155
+
156
+ Returns:
157
+ `List[bytes|str]`: The list of tokens.
158
+ """
159
+ tokens = []
160
+ text = unicodedata.normalize("NFC", text)
161
+
162
+ # this implementation takes a detour: text -> token id -> token surface forms
163
+ for t in self.tokenizer.encode(
164
+ text, allowed_special=allowed_special, disallowed_special=disallowed_special
165
+ ):
166
+ tokens.append(self.decoder[t])
167
+ return tokens
168
+
169
+ def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
170
+ """
171
+ Converts a sequence of tokens into a single string.
172
+ """
173
+ text = ""
174
+ temp = b""
175
+ for t in tokens:
176
+ if isinstance(t, str):
177
+ if temp:
178
+ text += temp.decode("utf-8", errors=self.errors)
179
+ temp = b""
180
+ text += t
181
+ elif isinstance(t, bytes):
182
+ temp += t
183
+ else:
184
+ raise TypeError("token should only be of type types or str")
185
+ if temp:
186
+ text += temp.decode("utf-8", errors=self.errors)
187
+ return text
188
+
189
+ @property
190
+ def vocab_size(self):
191
+ return self.tokenizer.n_vocab
192
+
193
+ def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
194
+ """Converts an id to a token, special tokens included"""
195
+ if index in self.decoder:
196
+ return self.decoder[index]
197
+ raise ValueError("unknown ids")
198
+
199
+ def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
200
+ """Converts a token to an id using the vocab, special tokens included"""
201
+ if token in self.special_tokens:
202
+ return self.special_tokens[token]
203
+ if token in self.mergeable_ranks:
204
+ return self.mergeable_ranks[token]
205
+ raise ValueError("unknown token")
206
+
207
+ def _tokenize(self, text: str, **kwargs):
208
+ """
209
+ Converts a string into a sequence of tokens (string), using the tokenizer. Splits into words for word-based
210
+ vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
211
+
212
+ Does NOT take care of added tokens.
213
+ """
214
+ raise NotImplementedError
215
+
216
+ def _decode(
217
+ self,
218
+ token_ids: Union[int, List[int]],
219
+ skip_special_tokens: bool = False,
220
+ **kwargs,
221
+ ) -> str:
222
+ if isinstance(token_ids, int):
223
+ token_ids = [token_ids]
224
+ if skip_special_tokens:
225
+ token_ids = [i for i in token_ids if i < self.eod_id]
226
+ return self.tokenizer.decode(token_ids, errors=self.errors)
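A direct-instantiation sketch of the tokenizer above (illustrative assumptions: the tiktoken package is installed and the qwen.tiktoken file from this commit sits in the working directory; loading through AutoTokenizer with trust_remote_code=True is the usual route):

from tokenization_qwen import QWenTokenizer

tok = QWenTokenizer("qwen.tiktoken")
ids = tok.encode("Hello, Qwen!")  # PreTrainedTokenizer.encode -> tokenize() -> convert_tokens_to_ids()
text = tok.decode(ids)            # routed through _decode() above
print(ids)
print(text)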
tokenizer_config.json ADDED
@@ -0,0 +1,15 @@
1
+ {
2
+ "added_tokens_decoder": {},
3
+ "additional_special_tokens": [],
4
+ "auto_map": {
5
+ "AutoTokenizer": [
6
+ "tokenization_qwen.QWenTokenizer",
7
+ null
8
+ ]
9
+ },
10
+ "clean_up_tokenization_spaces": true,
11
+ "model_max_length": 2048,
12
+ "padding_side": "right",
13
+ "tokenizer_class": "QWenTokenizer",
14
+ "tokenizer_file": null
15
+ }