Zoyd committed on
Commit
d86f3db
1 Parent(s): 6cc3dc5

Upload folder using huggingface_hub

README.md ADDED
@@ -0,0 +1,70 @@
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - en
5
+ - zh
6
+ metrics:
7
+ - accuracy
8
+ library_name: transformers
9
+ pipeline_tag: question-answering
10
+ tags:
11
+ - llm
12
+ - nanbeige
13
+ ---
14
+ <div align="center">
15
+ <h1>
16
+ Nanbeige2-16B-Chat
17
+ </h1>
18
+ </div>
19
+
20
+
21
+ ## <span id="Introduction">模型介绍(Introduction)</span>
22
+
23
+ Nanbeige2-16B-Chat是南北阁实验室最新研发的160亿参数模型,在预训练中使用4.5T Tokens高质量语料。
24
+ 在对齐阶段,我们首先使用了100万条样本进行SFT训练,然后用40万高质量且难度较高的样本进行课程学习,再通过人类反馈DPO,得到Nanbeige2-16B-Chat。Nanbeige2-16B-Chat在各个权威测评数据集上都取得了较优的效果。
25
+
26
+
27
+ Nanbeige2-16B-Chat is the latest 16B-parameter model developed by Nanbeige Lab, pre-trained on 4.5T tokens of high-quality data.
28
+ During the alignment phase, we first performed Supervised Fine-Tuning (SFT) on 1 million samples, then ran curriculum learning on 400,000 high-quality, higher-difficulty samples, and finally applied Direct Preference Optimization (DPO) with human feedback to obtain Nanbeige2-16B-Chat. Nanbeige2-16B-Chat achieves strong results across various authoritative benchmark datasets.
29
+
30
+
31
+ ## <span id="Inference">模型推理(Inference)</span>
32
+
33
+ ```python
34
+ from transformers import AutoModelForCausalLM, AutoTokenizer
35
+ tokenizer = AutoTokenizer.from_pretrained(
36
+ 'Nanbeige/Nanbeige2-16B-Chat',
37
+ use_fast=False,
38
+ trust_remote_code=True
39
+ )
40
+ model = AutoModelForCausalLM.from_pretrained(
41
+ 'Nanbeige/Nanbeige2-16B-Chat',
42
+ torch_dtype='auto',
43
+ device_map='auto',
44
+ trust_remote_code=True
45
+ )
46
+ messages = [
47
+ {'role': 'user', 'content': 'Hello'}
48
+ ]
49
+ prompt = tokenizer.apply_chat_template(
50
+ messages,
51
+ add_generation_prompt=True,
52
+ tokenize=False
53
+ )
54
+ input_ids = tokenizer(prompt, add_special_tokens=False, return_tensors='pt').input_ids
55
+ output_ids = model.generate(input_ids.to('cuda'))
56
+ resp = tokenizer.decode(output_ids[0][len(input_ids[0]):], skip_special_tokens=True)
57
+ print(resp)
58
+ ```
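The snippet above covers a single turn; multi-turn chat follows the same pattern by appending the assistant reply and the next user message before re-applying the chat template. A minimal sketch reusing the variables from the block above (assuming the bundled chat template handles the assistant role, as ChatML-style templates normally do):

```python
# Minimal multi-turn sketch reusing `tokenizer`, `model` and `resp` from above.
messages = [
    {'role': 'user', 'content': 'Hello'},
    {'role': 'assistant', 'content': resp},   # reply generated in the previous turn
    {'role': 'user', 'content': 'Tell me a short joke.'},
]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
input_ids = tokenizer(prompt, add_special_tokens=False, return_tensors='pt').input_ids
output_ids = model.generate(input_ids.to('cuda'), max_new_tokens=256)
print(tokenizer.decode(output_ids[0][len(input_ids[0]):], skip_special_tokens=True))
```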
59
+
60
+ ## <span id="Limitations">局限性(Limitations)</span>
61
+
62
+ 虽然我们在训练过程中非常注重模型的安全性,力求确保其输出符合伦理和法律要求的文本,但由于模型大小和概率生成范式的限制,无法完全避免产生各种不符合预期的输出情况。这些输出可能包含偏见、歧视等有害内容,请勿传播这些内容。我们不承担因传播不良信息而导致的任何后果。
63
+
64
+ Although we placed great emphasis on model safety during training, striving to ensure its outputs meet ethical and legal requirements, the model's size and probabilistic generation paradigm mean that unexpected outputs cannot be ruled out entirely. Such outputs may contain harmful content such as bias or discrimination; please do not disseminate it. We assume no responsibility for any consequences arising from the spread of harmful information.
65
+
66
+ ## <span id="License">协议(License)</span>
67
+
68
+ 使用 Nanbeige 模型时,您必须遵守 Apache 2.0 许可证和[《南北阁大语言模型许可协议》](https://huggingface.co/Nanbeige/Nanbeige-16B-Base-32k/resolve/main/南北阁大语言模型许可协议.pdf)。如果您打算将 Nanbeige 模型或其衍生产品用于商业目的,请通过以下联系邮箱 [email protected] 提交申请材料,以满足《南北阁大语言模型许可协议》的要求。经过审核后,我们将授予您非排他性、全球范围内、不可转让、不可再许可、可撤销的商业版权许可。
69
+
70
+ When using the Nanbeige models, you must comply with the Apache 2.0 License and the [License Agreement for Large Language Models Nanbeige](https://huggingface.co/Nanbeige/Nanbeige-16B-Base-32k/resolve/main/License_Agreement_for_Large_Language_Models_Nanbeige.pdf). If you intend to use the Nanbeige models or their derivatives for commercial purposes, please submit your application materials to [email protected], as required by the License Agreement for Large Language Models Nanbeige. After review, we will grant you a non-exclusive, worldwide, non-transferable, non-sublicensable and revocable commercial copyright license.
added_tokens.json ADDED
@@ -0,0 +1,5 @@
1
+ {
2
+ "<|endoftext|>": 58980,
3
+ "<|im_end|>": 58979,
4
+ "<|im_start|>": 58978
5
+ }
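These three added tokens (`<|im_start|>`, `<|im_end|>`, `<|endoftext|>`) point to a ChatML-style prompt format, and `<|im_end|>` (id 58979) is also the `eos_token_id` in config.json and generation_config.json. A hedged illustration of what the rendered prompt likely looks like, reusing `tokenizer` from the README snippet; the authoritative template is whatever the repo's tokenizer configuration ships, so treat the literal string below as an assumption:

```python
# Hypothetical rendering, assuming a ChatML-style chat template built on the
# <|im_start|>/<|im_end|> tokens above; verify against the actual output.
messages = [{'role': 'user', 'content': 'Hello'}]
rendered = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
print(rendered)
# Expected shape of the string (assumption):
# <|im_start|>user
# Hello<|im_end|>
# <|im_start|>assistant
```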
config.json ADDED
@@ -0,0 +1,44 @@
1
+ {
2
+ "architectures": [
3
+ "NanbeigeForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_nanbeige.NanbeigeConfig",
7
+ "AutoModelForCausalLM": "modeling_nanbeige.NanbeigeForCausalLM"
8
+ },
9
+ "attention_bias": false,
10
+ "attention_dropout": 0.0,
11
+ "bos_token_id": 1,
12
+ "eos_token_id": 58979,
13
+ "hidden_act": "silu",
14
+ "hidden_size": 5120,
15
+ "initializer_range": 0.02,
16
+ "intermediate_size": 13824,
17
+ "max_length": 4096,
18
+ "max_position_embeddings": 4096,
19
+ "model_type": "nanbeige",
20
+ "num_attention_heads": 40,
21
+ "num_hidden_layers": 48,
22
+ "num_key_value_heads": 40,
23
+ "pad_token_id": 0,
24
+ "pretraining_tp": 1,
25
+ "rms_norm_eps": 1e-05,
26
+ "rope_scaling": null,
27
+ "rope_theta": 10000.0,
28
+ "tie_word_embeddings": false,
29
+ "torch_dtype": "bfloat16",
30
+ "transformers_version": "4.38.0",
31
+ "use_cache": true,
32
+ "vocab_size": 59392,
33
+ "quantization_config": {
34
+ "quant_method": "exl2",
35
+ "version": "0.0.21",
36
+ "bits": 6.0,
37
+ "head_bits": 8,
38
+ "calibration": {
39
+ "rows": 100,
40
+ "length": 2048,
41
+ "dataset": "(default)"
42
+ }
43
+ }
44
+ }
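The `quantization_config` block marks this upload as an ExLlamaV2 (exl2) quant at 6.0 bits per weight with 8-bit head layers, produced with exllamav2 0.0.21. The transformers snippet in the README is written for the original full-precision checkpoint; exl2 weights are normally loaded through ExLlamaV2-based tooling instead. A rough sketch, with API names taken from the exllamav2 examples and the local path treated as an assumption:

```python
# Sketch of loading an exl2 quant with exllamav2; verify the calls against the
# exllamav2 version you have installed.
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2BaseGenerator, ExLlamaV2Sampler

config = ExLlamaV2Config()
config.model_dir = './Nanbeige2-16B-Chat-exl2'   # hypothetical local path to this repo
config.prepare()

model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, lazy=True)
model.load_autosplit(cache)                      # split layers across available GPUs

tokenizer = ExLlamaV2Tokenizer(config)
generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)
settings = ExLlamaV2Sampler.Settings()
settings.temperature = 0.3                       # mirrors generation_config.json
settings.top_p = 0.9

prompt = '<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n'
print(generator.generate_simple(prompt, settings, 200))
```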
configuration_nanbeige.py ADDED
@@ -0,0 +1,53 @@
1
+ # Copyright (c) 2023 Nanbeige LLM Lab All Rights Reserved.
2
+
3
+ """ Nanbeige model configuration"""
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+ from transformers.utils import logging
7
+
8
+
9
+ logger = logging.get_logger(__name__)
10
+
11
+ NANBEIGE_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
12
+
13
+
14
+ class NanbeigeConfig(PretrainedConfig):
15
+ model_type = "nanbeige"
16
+
17
+ def __init__(
18
+ self,
19
+ vocab_size=32000,
20
+ hidden_size=4096,
21
+ intermediate_size=11008,
22
+ num_hidden_layers=32,
23
+ num_attention_heads=32,
24
+ hidden_act="silu",
25
+ max_position_embeddings=2048,
26
+ initializer_range=0.02,
27
+ rms_norm_eps=1e-6,
28
+ use_cache=True,
29
+ pad_token_id=0,
30
+ bos_token_id=1,
31
+ eos_token_id=2,
32
+ tie_word_embeddings=False,
33
+ yarn_scale=1.,
34
+ **kwargs,
35
+ ):
36
+ self.vocab_size = vocab_size
37
+ self.max_position_embeddings = max_position_embeddings
38
+ self.hidden_size = hidden_size
39
+ self.intermediate_size = intermediate_size
40
+ self.num_hidden_layers = num_hidden_layers
41
+ self.num_attention_heads = num_attention_heads
42
+ self.hidden_act = hidden_act
43
+ self.initializer_range = initializer_range
44
+ self.rms_norm_eps = rms_norm_eps
45
+ self.use_cache = use_cache
46
+ self.yarn_scale = yarn_scale
47
+ super().__init__(
48
+ pad_token_id=pad_token_id,
49
+ bos_token_id=bos_token_id,
50
+ eos_token_id=eos_token_id,
51
+ tie_word_embeddings=tie_word_embeddings,
52
+ **kwargs,
53
+ )
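Because config.json maps `AutoConfig` to this class via `auto_map`, the custom configuration is picked up transparently when `trust_remote_code=True`. A small sketch of inspecting it; the printed values are the ones from config.json above, and `yarn_scale` falls back to the 1.0 default since config.json does not set it:

```python
# Minimal sketch: load the remote-code config and read a few fields.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained('Nanbeige/Nanbeige2-16B-Chat', trust_remote_code=True)
print(cfg.model_type)                                                    # 'nanbeige'
print(cfg.hidden_size, cfg.num_hidden_layers, cfg.num_attention_heads)   # 5120 48 40
print(cfg.yarn_scale)                    # 1.0 -> plain RoPE; values > 1 enable YaRN scaling
```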
generation_config.json ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 58979,
5
+ "max_length": 4096,
6
+ "pad_token_id": 0,
7
+ "do_sample": true,
8
+ "temperature": 0.3,
9
+ "top_p": 0.9,
10
+ "transformers_version": "4.38.0"
11
+ }
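These defaults (sampling enabled, temperature 0.3, top_p 0.9, overall length capped at 4096) are what `model.generate()` uses when no arguments are passed, as in the README snippet. They can be overridden per call; a brief sketch reusing `input_ids` from above:

```python
# Override the repo's generation defaults for a single call.
output_ids = model.generate(
    input_ids.to('cuda'),
    do_sample=True,
    temperature=0.7,      # example value; the shipped default is 0.3
    top_p=0.9,
    max_new_tokens=512,   # the config only caps total length at 4096
)
```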
model.safetensors.index.json ADDED
@@ -0,0 +1,442 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 31667988480
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "model-00007-of-00007.safetensors",
7
+ "model.embed_tokens.weight": "model-00001-of-00007.safetensors",
8
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00007.safetensors",
9
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
10
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00007.safetensors",
11
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00007.safetensors",
12
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00007.safetensors",
13
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00007.safetensors",
14
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00007.safetensors",
15
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00007.safetensors",
16
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00007.safetensors",
17
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00007.safetensors",
18
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
19
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00007.safetensors",
20
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00007.safetensors",
21
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00007.safetensors",
22
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00007.safetensors",
23
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00007.safetensors",
24
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00007.safetensors",
25
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00007.safetensors",
26
+ "model.layers.10.input_layernorm.weight": "model-00002-of-00007.safetensors",
27
+ "model.layers.10.mlp.down_proj.weight": "model-00002-of-00007.safetensors",
28
+ "model.layers.10.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
29
+ "model.layers.10.mlp.up_proj.weight": "model-00002-of-00007.safetensors",
30
+ "model.layers.10.post_attention_layernorm.weight": "model-00002-of-00007.safetensors",
31
+ "model.layers.10.self_attn.k_proj.weight": "model-00002-of-00007.safetensors",
32
+ "model.layers.10.self_attn.o_proj.weight": "model-00002-of-00007.safetensors",
33
+ "model.layers.10.self_attn.q_proj.weight": "model-00002-of-00007.safetensors",
34
+ "model.layers.10.self_attn.v_proj.weight": "model-00002-of-00007.safetensors",
35
+ "model.layers.11.input_layernorm.weight": "model-00002-of-00007.safetensors",
36
+ "model.layers.11.mlp.down_proj.weight": "model-00002-of-00007.safetensors",
37
+ "model.layers.11.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
38
+ "model.layers.11.mlp.up_proj.weight": "model-00002-of-00007.safetensors",
39
+ "model.layers.11.post_attention_layernorm.weight": "model-00002-of-00007.safetensors",
40
+ "model.layers.11.self_attn.k_proj.weight": "model-00002-of-00007.safetensors",
41
+ "model.layers.11.self_attn.o_proj.weight": "model-00002-of-00007.safetensors",
42
+ "model.layers.11.self_attn.q_proj.weight": "model-00002-of-00007.safetensors",
43
+ "model.layers.11.self_attn.v_proj.weight": "model-00002-of-00007.safetensors",
44
+ "model.layers.12.input_layernorm.weight": "model-00002-of-00007.safetensors",
45
+ "model.layers.12.mlp.down_proj.weight": "model-00002-of-00007.safetensors",
46
+ "model.layers.12.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
47
+ "model.layers.12.mlp.up_proj.weight": "model-00002-of-00007.safetensors",
48
+ "model.layers.12.post_attention_layernorm.weight": "model-00002-of-00007.safetensors",
49
+ "model.layers.12.self_attn.k_proj.weight": "model-00002-of-00007.safetensors",
50
+ "model.layers.12.self_attn.o_proj.weight": "model-00002-of-00007.safetensors",
51
+ "model.layers.12.self_attn.q_proj.weight": "model-00002-of-00007.safetensors",
52
+ "model.layers.12.self_attn.v_proj.weight": "model-00002-of-00007.safetensors",
53
+ "model.layers.13.input_layernorm.weight": "model-00002-of-00007.safetensors",
54
+ "model.layers.13.mlp.down_proj.weight": "model-00002-of-00007.safetensors",
55
+ "model.layers.13.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
56
+ "model.layers.13.mlp.up_proj.weight": "model-00002-of-00007.safetensors",
57
+ "model.layers.13.post_attention_layernorm.weight": "model-00002-of-00007.safetensors",
58
+ "model.layers.13.self_attn.k_proj.weight": "model-00002-of-00007.safetensors",
59
+ "model.layers.13.self_attn.o_proj.weight": "model-00002-of-00007.safetensors",
60
+ "model.layers.13.self_attn.q_proj.weight": "model-00002-of-00007.safetensors",
61
+ "model.layers.13.self_attn.v_proj.weight": "model-00002-of-00007.safetensors",
62
+ "model.layers.14.input_layernorm.weight": "model-00003-of-00007.safetensors",
63
+ "model.layers.14.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
64
+ "model.layers.14.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
65
+ "model.layers.14.mlp.up_proj.weight": "model-00003-of-00007.safetensors",
66
+ "model.layers.14.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
67
+ "model.layers.14.self_attn.k_proj.weight": "model-00002-of-00007.safetensors",
68
+ "model.layers.14.self_attn.o_proj.weight": "model-00002-of-00007.safetensors",
69
+ "model.layers.14.self_attn.q_proj.weight": "model-00002-of-00007.safetensors",
70
+ "model.layers.14.self_attn.v_proj.weight": "model-00002-of-00007.safetensors",
71
+ "model.layers.15.input_layernorm.weight": "model-00003-of-00007.safetensors",
72
+ "model.layers.15.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
73
+ "model.layers.15.mlp.gate_proj.weight": "model-00003-of-00007.safetensors",
74
+ "model.layers.15.mlp.up_proj.weight": "model-00003-of-00007.safetensors",
75
+ "model.layers.15.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
76
+ "model.layers.15.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
77
+ "model.layers.15.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
78
+ "model.layers.15.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
79
+ "model.layers.15.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
80
+ "model.layers.16.input_layernorm.weight": "model-00003-of-00007.safetensors",
81
+ "model.layers.16.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
82
+ "model.layers.16.mlp.gate_proj.weight": "model-00003-of-00007.safetensors",
83
+ "model.layers.16.mlp.up_proj.weight": "model-00003-of-00007.safetensors",
84
+ "model.layers.16.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
85
+ "model.layers.16.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
86
+ "model.layers.16.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
87
+ "model.layers.16.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
88
+ "model.layers.16.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
89
+ "model.layers.17.input_layernorm.weight": "model-00003-of-00007.safetensors",
90
+ "model.layers.17.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
91
+ "model.layers.17.mlp.gate_proj.weight": "model-00003-of-00007.safetensors",
92
+ "model.layers.17.mlp.up_proj.weight": "model-00003-of-00007.safetensors",
93
+ "model.layers.17.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
94
+ "model.layers.17.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
95
+ "model.layers.17.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
96
+ "model.layers.17.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
97
+ "model.layers.17.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
98
+ "model.layers.18.input_layernorm.weight": "model-00003-of-00007.safetensors",
99
+ "model.layers.18.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
100
+ "model.layers.18.mlp.gate_proj.weight": "model-00003-of-00007.safetensors",
101
+ "model.layers.18.mlp.up_proj.weight": "model-00003-of-00007.safetensors",
102
+ "model.layers.18.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
103
+ "model.layers.18.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
104
+ "model.layers.18.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
105
+ "model.layers.18.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
106
+ "model.layers.18.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
107
+ "model.layers.19.input_layernorm.weight": "model-00003-of-00007.safetensors",
108
+ "model.layers.19.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
109
+ "model.layers.19.mlp.gate_proj.weight": "model-00003-of-00007.safetensors",
110
+ "model.layers.19.mlp.up_proj.weight": "model-00003-of-00007.safetensors",
111
+ "model.layers.19.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
112
+ "model.layers.19.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
113
+ "model.layers.19.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
114
+ "model.layers.19.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
115
+ "model.layers.19.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
116
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00007.safetensors",
117
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
118
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00007.safetensors",
119
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00007.safetensors",
120
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00007.safetensors",
121
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00007.safetensors",
122
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00007.safetensors",
123
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00007.safetensors",
124
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00007.safetensors",
125
+ "model.layers.20.input_layernorm.weight": "model-00003-of-00007.safetensors",
126
+ "model.layers.20.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
127
+ "model.layers.20.mlp.gate_proj.weight": "model-00003-of-00007.safetensors",
128
+ "model.layers.20.mlp.up_proj.weight": "model-00003-of-00007.safetensors",
129
+ "model.layers.20.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
130
+ "model.layers.20.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
131
+ "model.layers.20.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
132
+ "model.layers.20.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
133
+ "model.layers.20.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
134
+ "model.layers.21.input_layernorm.weight": "model-00003-of-00007.safetensors",
135
+ "model.layers.21.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
136
+ "model.layers.21.mlp.gate_proj.weight": "model-00003-of-00007.safetensors",
137
+ "model.layers.21.mlp.up_proj.weight": "model-00003-of-00007.safetensors",
138
+ "model.layers.21.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
139
+ "model.layers.21.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
140
+ "model.layers.21.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
141
+ "model.layers.21.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
142
+ "model.layers.21.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
143
+ "model.layers.22.input_layernorm.weight": "model-00004-of-00007.safetensors",
144
+ "model.layers.22.mlp.down_proj.weight": "model-00004-of-00007.safetensors",
145
+ "model.layers.22.mlp.gate_proj.weight": "model-00004-of-00007.safetensors",
146
+ "model.layers.22.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
147
+ "model.layers.22.post_attention_layernorm.weight": "model-00004-of-00007.safetensors",
148
+ "model.layers.22.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
149
+ "model.layers.22.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
150
+ "model.layers.22.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
151
+ "model.layers.22.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
152
+ "model.layers.23.input_layernorm.weight": "model-00004-of-00007.safetensors",
153
+ "model.layers.23.mlp.down_proj.weight": "model-00004-of-00007.safetensors",
154
+ "model.layers.23.mlp.gate_proj.weight": "model-00004-of-00007.safetensors",
155
+ "model.layers.23.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
156
+ "model.layers.23.post_attention_layernorm.weight": "model-00004-of-00007.safetensors",
157
+ "model.layers.23.self_attn.k_proj.weight": "model-00004-of-00007.safetensors",
158
+ "model.layers.23.self_attn.o_proj.weight": "model-00004-of-00007.safetensors",
159
+ "model.layers.23.self_attn.q_proj.weight": "model-00004-of-00007.safetensors",
160
+ "model.layers.23.self_attn.v_proj.weight": "model-00004-of-00007.safetensors",
161
+ "model.layers.24.input_layernorm.weight": "model-00004-of-00007.safetensors",
162
+ "model.layers.24.mlp.down_proj.weight": "model-00004-of-00007.safetensors",
163
+ "model.layers.24.mlp.gate_proj.weight": "model-00004-of-00007.safetensors",
164
+ "model.layers.24.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
165
+ "model.layers.24.post_attention_layernorm.weight": "model-00004-of-00007.safetensors",
166
+ "model.layers.24.self_attn.k_proj.weight": "model-00004-of-00007.safetensors",
167
+ "model.layers.24.self_attn.o_proj.weight": "model-00004-of-00007.safetensors",
168
+ "model.layers.24.self_attn.q_proj.weight": "model-00004-of-00007.safetensors",
169
+ "model.layers.24.self_attn.v_proj.weight": "model-00004-of-00007.safetensors",
170
+ "model.layers.25.input_layernorm.weight": "model-00004-of-00007.safetensors",
171
+ "model.layers.25.mlp.down_proj.weight": "model-00004-of-00007.safetensors",
172
+ "model.layers.25.mlp.gate_proj.weight": "model-00004-of-00007.safetensors",
173
+ "model.layers.25.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
174
+ "model.layers.25.post_attention_layernorm.weight": "model-00004-of-00007.safetensors",
175
+ "model.layers.25.self_attn.k_proj.weight": "model-00004-of-00007.safetensors",
176
+ "model.layers.25.self_attn.o_proj.weight": "model-00004-of-00007.safetensors",
177
+ "model.layers.25.self_attn.q_proj.weight": "model-00004-of-00007.safetensors",
178
+ "model.layers.25.self_attn.v_proj.weight": "model-00004-of-00007.safetensors",
179
+ "model.layers.26.input_layernorm.weight": "model-00004-of-00007.safetensors",
180
+ "model.layers.26.mlp.down_proj.weight": "model-00004-of-00007.safetensors",
181
+ "model.layers.26.mlp.gate_proj.weight": "model-00004-of-00007.safetensors",
182
+ "model.layers.26.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
183
+ "model.layers.26.post_attention_layernorm.weight": "model-00004-of-00007.safetensors",
184
+ "model.layers.26.self_attn.k_proj.weight": "model-00004-of-00007.safetensors",
185
+ "model.layers.26.self_attn.o_proj.weight": "model-00004-of-00007.safetensors",
186
+ "model.layers.26.self_attn.q_proj.weight": "model-00004-of-00007.safetensors",
187
+ "model.layers.26.self_attn.v_proj.weight": "model-00004-of-00007.safetensors",
188
+ "model.layers.27.input_layernorm.weight": "model-00004-of-00007.safetensors",
189
+ "model.layers.27.mlp.down_proj.weight": "model-00004-of-00007.safetensors",
190
+ "model.layers.27.mlp.gate_proj.weight": "model-00004-of-00007.safetensors",
191
+ "model.layers.27.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
192
+ "model.layers.27.post_attention_layernorm.weight": "model-00004-of-00007.safetensors",
193
+ "model.layers.27.self_attn.k_proj.weight": "model-00004-of-00007.safetensors",
194
+ "model.layers.27.self_attn.o_proj.weight": "model-00004-of-00007.safetensors",
195
+ "model.layers.27.self_attn.q_proj.weight": "model-00004-of-00007.safetensors",
196
+ "model.layers.27.self_attn.v_proj.weight": "model-00004-of-00007.safetensors",
197
+ "model.layers.28.input_layernorm.weight": "model-00004-of-00007.safetensors",
198
+ "model.layers.28.mlp.down_proj.weight": "model-00004-of-00007.safetensors",
199
+ "model.layers.28.mlp.gate_proj.weight": "model-00004-of-00007.safetensors",
200
+ "model.layers.28.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
201
+ "model.layers.28.post_attention_layernorm.weight": "model-00004-of-00007.safetensors",
202
+ "model.layers.28.self_attn.k_proj.weight": "model-00004-of-00007.safetensors",
203
+ "model.layers.28.self_attn.o_proj.weight": "model-00004-of-00007.safetensors",
204
+ "model.layers.28.self_attn.q_proj.weight": "model-00004-of-00007.safetensors",
205
+ "model.layers.28.self_attn.v_proj.weight": "model-00004-of-00007.safetensors",
206
+ "model.layers.29.input_layernorm.weight": "model-00004-of-00007.safetensors",
207
+ "model.layers.29.mlp.down_proj.weight": "model-00004-of-00007.safetensors",
208
+ "model.layers.29.mlp.gate_proj.weight": "model-00004-of-00007.safetensors",
209
+ "model.layers.29.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
210
+ "model.layers.29.post_attention_layernorm.weight": "model-00004-of-00007.safetensors",
211
+ "model.layers.29.self_attn.k_proj.weight": "model-00004-of-00007.safetensors",
212
+ "model.layers.29.self_attn.o_proj.weight": "model-00004-of-00007.safetensors",
213
+ "model.layers.29.self_attn.q_proj.weight": "model-00004-of-00007.safetensors",
214
+ "model.layers.29.self_attn.v_proj.weight": "model-00004-of-00007.safetensors",
215
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00007.safetensors",
216
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
217
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00007.safetensors",
218
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00007.safetensors",
219
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00007.safetensors",
220
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00007.safetensors",
221
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00007.safetensors",
222
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00007.safetensors",
223
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00007.safetensors",
224
+ "model.layers.30.input_layernorm.weight": "model-00005-of-00007.safetensors",
225
+ "model.layers.30.mlp.down_proj.weight": "model-00005-of-00007.safetensors",
226
+ "model.layers.30.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
227
+ "model.layers.30.mlp.up_proj.weight": "model-00005-of-00007.safetensors",
228
+ "model.layers.30.post_attention_layernorm.weight": "model-00005-of-00007.safetensors",
229
+ "model.layers.30.self_attn.k_proj.weight": "model-00004-of-00007.safetensors",
230
+ "model.layers.30.self_attn.o_proj.weight": "model-00005-of-00007.safetensors",
231
+ "model.layers.30.self_attn.q_proj.weight": "model-00004-of-00007.safetensors",
232
+ "model.layers.30.self_attn.v_proj.weight": "model-00005-of-00007.safetensors",
233
+ "model.layers.31.input_layernorm.weight": "model-00005-of-00007.safetensors",
234
+ "model.layers.31.mlp.down_proj.weight": "model-00005-of-00007.safetensors",
235
+ "model.layers.31.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
236
+ "model.layers.31.mlp.up_proj.weight": "model-00005-of-00007.safetensors",
237
+ "model.layers.31.post_attention_layernorm.weight": "model-00005-of-00007.safetensors",
238
+ "model.layers.31.self_attn.k_proj.weight": "model-00005-of-00007.safetensors",
239
+ "model.layers.31.self_attn.o_proj.weight": "model-00005-of-00007.safetensors",
240
+ "model.layers.31.self_attn.q_proj.weight": "model-00005-of-00007.safetensors",
241
+ "model.layers.31.self_attn.v_proj.weight": "model-00005-of-00007.safetensors",
242
+ "model.layers.32.input_layernorm.weight": "model-00005-of-00007.safetensors",
243
+ "model.layers.32.mlp.down_proj.weight": "model-00005-of-00007.safetensors",
244
+ "model.layers.32.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
245
+ "model.layers.32.mlp.up_proj.weight": "model-00005-of-00007.safetensors",
246
+ "model.layers.32.post_attention_layernorm.weight": "model-00005-of-00007.safetensors",
247
+ "model.layers.32.self_attn.k_proj.weight": "model-00005-of-00007.safetensors",
248
+ "model.layers.32.self_attn.o_proj.weight": "model-00005-of-00007.safetensors",
249
+ "model.layers.32.self_attn.q_proj.weight": "model-00005-of-00007.safetensors",
250
+ "model.layers.32.self_attn.v_proj.weight": "model-00005-of-00007.safetensors",
251
+ "model.layers.33.input_layernorm.weight": "model-00005-of-00007.safetensors",
252
+ "model.layers.33.mlp.down_proj.weight": "model-00005-of-00007.safetensors",
253
+ "model.layers.33.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
254
+ "model.layers.33.mlp.up_proj.weight": "model-00005-of-00007.safetensors",
255
+ "model.layers.33.post_attention_layernorm.weight": "model-00005-of-00007.safetensors",
256
+ "model.layers.33.self_attn.k_proj.weight": "model-00005-of-00007.safetensors",
257
+ "model.layers.33.self_attn.o_proj.weight": "model-00005-of-00007.safetensors",
258
+ "model.layers.33.self_attn.q_proj.weight": "model-00005-of-00007.safetensors",
259
+ "model.layers.33.self_attn.v_proj.weight": "model-00005-of-00007.safetensors",
260
+ "model.layers.34.input_layernorm.weight": "model-00005-of-00007.safetensors",
261
+ "model.layers.34.mlp.down_proj.weight": "model-00005-of-00007.safetensors",
262
+ "model.layers.34.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
263
+ "model.layers.34.mlp.up_proj.weight": "model-00005-of-00007.safetensors",
264
+ "model.layers.34.post_attention_layernorm.weight": "model-00005-of-00007.safetensors",
265
+ "model.layers.34.self_attn.k_proj.weight": "model-00005-of-00007.safetensors",
266
+ "model.layers.34.self_attn.o_proj.weight": "model-00005-of-00007.safetensors",
267
+ "model.layers.34.self_attn.q_proj.weight": "model-00005-of-00007.safetensors",
268
+ "model.layers.34.self_attn.v_proj.weight": "model-00005-of-00007.safetensors",
269
+ "model.layers.35.input_layernorm.weight": "model-00005-of-00007.safetensors",
270
+ "model.layers.35.mlp.down_proj.weight": "model-00005-of-00007.safetensors",
271
+ "model.layers.35.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
272
+ "model.layers.35.mlp.up_proj.weight": "model-00005-of-00007.safetensors",
273
+ "model.layers.35.post_attention_layernorm.weight": "model-00005-of-00007.safetensors",
274
+ "model.layers.35.self_attn.k_proj.weight": "model-00005-of-00007.safetensors",
275
+ "model.layers.35.self_attn.o_proj.weight": "model-00005-of-00007.safetensors",
276
+ "model.layers.35.self_attn.q_proj.weight": "model-00005-of-00007.safetensors",
277
+ "model.layers.35.self_attn.v_proj.weight": "model-00005-of-00007.safetensors",
278
+ "model.layers.36.input_layernorm.weight": "model-00005-of-00007.safetensors",
279
+ "model.layers.36.mlp.down_proj.weight": "model-00005-of-00007.safetensors",
280
+ "model.layers.36.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
281
+ "model.layers.36.mlp.up_proj.weight": "model-00005-of-00007.safetensors",
282
+ "model.layers.36.post_attention_layernorm.weight": "model-00005-of-00007.safetensors",
283
+ "model.layers.36.self_attn.k_proj.weight": "model-00005-of-00007.safetensors",
284
+ "model.layers.36.self_attn.o_proj.weight": "model-00005-of-00007.safetensors",
285
+ "model.layers.36.self_attn.q_proj.weight": "model-00005-of-00007.safetensors",
286
+ "model.layers.36.self_attn.v_proj.weight": "model-00005-of-00007.safetensors",
287
+ "model.layers.37.input_layernorm.weight": "model-00005-of-00007.safetensors",
288
+ "model.layers.37.mlp.down_proj.weight": "model-00005-of-00007.safetensors",
289
+ "model.layers.37.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
290
+ "model.layers.37.mlp.up_proj.weight": "model-00005-of-00007.safetensors",
291
+ "model.layers.37.post_attention_layernorm.weight": "model-00005-of-00007.safetensors",
292
+ "model.layers.37.self_attn.k_proj.weight": "model-00005-of-00007.safetensors",
293
+ "model.layers.37.self_attn.o_proj.weight": "model-00005-of-00007.safetensors",
294
+ "model.layers.37.self_attn.q_proj.weight": "model-00005-of-00007.safetensors",
295
+ "model.layers.37.self_attn.v_proj.weight": "model-00005-of-00007.safetensors",
296
+ "model.layers.38.input_layernorm.weight": "model-00006-of-00007.safetensors",
297
+ "model.layers.38.mlp.down_proj.weight": "model-00006-of-00007.safetensors",
298
+ "model.layers.38.mlp.gate_proj.weight": "model-00006-of-00007.safetensors",
299
+ "model.layers.38.mlp.up_proj.weight": "model-00006-of-00007.safetensors",
300
+ "model.layers.38.post_attention_layernorm.weight": "model-00006-of-00007.safetensors",
301
+ "model.layers.38.self_attn.k_proj.weight": "model-00006-of-00007.safetensors",
302
+ "model.layers.38.self_attn.o_proj.weight": "model-00006-of-00007.safetensors",
303
+ "model.layers.38.self_attn.q_proj.weight": "model-00006-of-00007.safetensors",
304
+ "model.layers.38.self_attn.v_proj.weight": "model-00006-of-00007.safetensors",
305
+ "model.layers.39.input_layernorm.weight": "model-00006-of-00007.safetensors",
306
+ "model.layers.39.mlp.down_proj.weight": "model-00006-of-00007.safetensors",
307
+ "model.layers.39.mlp.gate_proj.weight": "model-00006-of-00007.safetensors",
308
+ "model.layers.39.mlp.up_proj.weight": "model-00006-of-00007.safetensors",
309
+ "model.layers.39.post_attention_layernorm.weight": "model-00006-of-00007.safetensors",
310
+ "model.layers.39.self_attn.k_proj.weight": "model-00006-of-00007.safetensors",
311
+ "model.layers.39.self_attn.o_proj.weight": "model-00006-of-00007.safetensors",
312
+ "model.layers.39.self_attn.q_proj.weight": "model-00006-of-00007.safetensors",
313
+ "model.layers.39.self_attn.v_proj.weight": "model-00006-of-00007.safetensors",
314
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00007.safetensors",
315
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
316
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00007.safetensors",
317
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00007.safetensors",
318
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00007.safetensors",
319
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00007.safetensors",
320
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00007.safetensors",
321
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00007.safetensors",
322
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00007.safetensors",
323
+ "model.layers.40.input_layernorm.weight": "model-00006-of-00007.safetensors",
324
+ "model.layers.40.mlp.down_proj.weight": "model-00006-of-00007.safetensors",
325
+ "model.layers.40.mlp.gate_proj.weight": "model-00006-of-00007.safetensors",
326
+ "model.layers.40.mlp.up_proj.weight": "model-00006-of-00007.safetensors",
327
+ "model.layers.40.post_attention_layernorm.weight": "model-00006-of-00007.safetensors",
328
+ "model.layers.40.self_attn.k_proj.weight": "model-00006-of-00007.safetensors",
329
+ "model.layers.40.self_attn.o_proj.weight": "model-00006-of-00007.safetensors",
330
+ "model.layers.40.self_attn.q_proj.weight": "model-00006-of-00007.safetensors",
331
+ "model.layers.40.self_attn.v_proj.weight": "model-00006-of-00007.safetensors",
332
+ "model.layers.41.input_layernorm.weight": "model-00006-of-00007.safetensors",
333
+ "model.layers.41.mlp.down_proj.weight": "model-00006-of-00007.safetensors",
334
+ "model.layers.41.mlp.gate_proj.weight": "model-00006-of-00007.safetensors",
335
+ "model.layers.41.mlp.up_proj.weight": "model-00006-of-00007.safetensors",
336
+ "model.layers.41.post_attention_layernorm.weight": "model-00006-of-00007.safetensors",
337
+ "model.layers.41.self_attn.k_proj.weight": "model-00006-of-00007.safetensors",
338
+ "model.layers.41.self_attn.o_proj.weight": "model-00006-of-00007.safetensors",
339
+ "model.layers.41.self_attn.q_proj.weight": "model-00006-of-00007.safetensors",
340
+ "model.layers.41.self_attn.v_proj.weight": "model-00006-of-00007.safetensors",
341
+ "model.layers.42.input_layernorm.weight": "model-00006-of-00007.safetensors",
342
+ "model.layers.42.mlp.down_proj.weight": "model-00006-of-00007.safetensors",
343
+ "model.layers.42.mlp.gate_proj.weight": "model-00006-of-00007.safetensors",
344
+ "model.layers.42.mlp.up_proj.weight": "model-00006-of-00007.safetensors",
345
+ "model.layers.42.post_attention_layernorm.weight": "model-00006-of-00007.safetensors",
346
+ "model.layers.42.self_attn.k_proj.weight": "model-00006-of-00007.safetensors",
347
+ "model.layers.42.self_attn.o_proj.weight": "model-00006-of-00007.safetensors",
348
+ "model.layers.42.self_attn.q_proj.weight": "model-00006-of-00007.safetensors",
349
+ "model.layers.42.self_attn.v_proj.weight": "model-00006-of-00007.safetensors",
350
+ "model.layers.43.input_layernorm.weight": "model-00006-of-00007.safetensors",
351
+ "model.layers.43.mlp.down_proj.weight": "model-00006-of-00007.safetensors",
352
+ "model.layers.43.mlp.gate_proj.weight": "model-00006-of-00007.safetensors",
353
+ "model.layers.43.mlp.up_proj.weight": "model-00006-of-00007.safetensors",
354
+ "model.layers.43.post_attention_layernorm.weight": "model-00006-of-00007.safetensors",
355
+ "model.layers.43.self_attn.k_proj.weight": "model-00006-of-00007.safetensors",
356
+ "model.layers.43.self_attn.o_proj.weight": "model-00006-of-00007.safetensors",
357
+ "model.layers.43.self_attn.q_proj.weight": "model-00006-of-00007.safetensors",
358
+ "model.layers.43.self_attn.v_proj.weight": "model-00006-of-00007.safetensors",
359
+ "model.layers.44.input_layernorm.weight": "model-00006-of-00007.safetensors",
360
+ "model.layers.44.mlp.down_proj.weight": "model-00006-of-00007.safetensors",
361
+ "model.layers.44.mlp.gate_proj.weight": "model-00006-of-00007.safetensors",
362
+ "model.layers.44.mlp.up_proj.weight": "model-00006-of-00007.safetensors",
363
+ "model.layers.44.post_attention_layernorm.weight": "model-00006-of-00007.safetensors",
364
+ "model.layers.44.self_attn.k_proj.weight": "model-00006-of-00007.safetensors",
365
+ "model.layers.44.self_attn.o_proj.weight": "model-00006-of-00007.safetensors",
366
+ "model.layers.44.self_attn.q_proj.weight": "model-00006-of-00007.safetensors",
367
+ "model.layers.44.self_attn.v_proj.weight": "model-00006-of-00007.safetensors",
368
+ "model.layers.45.input_layernorm.weight": "model-00007-of-00007.safetensors",
369
+ "model.layers.45.mlp.down_proj.weight": "model-00007-of-00007.safetensors",
370
+ "model.layers.45.mlp.gate_proj.weight": "model-00006-of-00007.safetensors",
371
+ "model.layers.45.mlp.up_proj.weight": "model-00006-of-00007.safetensors",
372
+ "model.layers.45.post_attention_layernorm.weight": "model-00007-of-00007.safetensors",
373
+ "model.layers.45.self_attn.k_proj.weight": "model-00006-of-00007.safetensors",
374
+ "model.layers.45.self_attn.o_proj.weight": "model-00006-of-00007.safetensors",
375
+ "model.layers.45.self_attn.q_proj.weight": "model-00006-of-00007.safetensors",
376
+ "model.layers.45.self_attn.v_proj.weight": "model-00006-of-00007.safetensors",
377
+ "model.layers.46.input_layernorm.weight": "model-00007-of-00007.safetensors",
378
+ "model.layers.46.mlp.down_proj.weight": "model-00007-of-00007.safetensors",
379
+ "model.layers.46.mlp.gate_proj.weight": "model-00007-of-00007.safetensors",
380
+ "model.layers.46.mlp.up_proj.weight": "model-00007-of-00007.safetensors",
381
+ "model.layers.46.post_attention_layernorm.weight": "model-00007-of-00007.safetensors",
382
+ "model.layers.46.self_attn.k_proj.weight": "model-00007-of-00007.safetensors",
383
+ "model.layers.46.self_attn.o_proj.weight": "model-00007-of-00007.safetensors",
384
+ "model.layers.46.self_attn.q_proj.weight": "model-00007-of-00007.safetensors",
385
+ "model.layers.46.self_attn.v_proj.weight": "model-00007-of-00007.safetensors",
386
+ "model.layers.47.input_layernorm.weight": "model-00007-of-00007.safetensors",
387
+ "model.layers.47.mlp.down_proj.weight": "model-00007-of-00007.safetensors",
388
+ "model.layers.47.mlp.gate_proj.weight": "model-00007-of-00007.safetensors",
389
+ "model.layers.47.mlp.up_proj.weight": "model-00007-of-00007.safetensors",
390
+ "model.layers.47.post_attention_layernorm.weight": "model-00007-of-00007.safetensors",
391
+ "model.layers.47.self_attn.k_proj.weight": "model-00007-of-00007.safetensors",
392
+ "model.layers.47.self_attn.o_proj.weight": "model-00007-of-00007.safetensors",
393
+ "model.layers.47.self_attn.q_proj.weight": "model-00007-of-00007.safetensors",
394
+ "model.layers.47.self_attn.v_proj.weight": "model-00007-of-00007.safetensors",
395
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00007.safetensors",
396
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
397
+ "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00007.safetensors",
398
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00007.safetensors",
399
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00007.safetensors",
400
+ "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00007.safetensors",
401
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00007.safetensors",
402
+ "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00007.safetensors",
403
+ "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00007.safetensors",
404
+ "model.layers.6.input_layernorm.weight": "model-00002-of-00007.safetensors",
405
+ "model.layers.6.mlp.down_proj.weight": "model-00002-of-00007.safetensors",
406
+ "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00007.safetensors",
407
+ "model.layers.6.mlp.up_proj.weight": "model-00001-of-00007.safetensors",
408
+ "model.layers.6.post_attention_layernorm.weight": "model-00002-of-00007.safetensors",
409
+ "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00007.safetensors",
410
+ "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00007.safetensors",
411
+ "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00007.safetensors",
412
+ "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00007.safetensors",
413
+ "model.layers.7.input_layernorm.weight": "model-00002-of-00007.safetensors",
414
+ "model.layers.7.mlp.down_proj.weight": "model-00002-of-00007.safetensors",
415
+ "model.layers.7.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
416
+ "model.layers.7.mlp.up_proj.weight": "model-00002-of-00007.safetensors",
417
+ "model.layers.7.post_attention_layernorm.weight": "model-00002-of-00007.safetensors",
418
+ "model.layers.7.self_attn.k_proj.weight": "model-00002-of-00007.safetensors",
419
+ "model.layers.7.self_attn.o_proj.weight": "model-00002-of-00007.safetensors",
420
+ "model.layers.7.self_attn.q_proj.weight": "model-00002-of-00007.safetensors",
421
+ "model.layers.7.self_attn.v_proj.weight": "model-00002-of-00007.safetensors",
422
+ "model.layers.8.input_layernorm.weight": "model-00002-of-00007.safetensors",
423
+ "model.layers.8.mlp.down_proj.weight": "model-00002-of-00007.safetensors",
424
+ "model.layers.8.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
425
+ "model.layers.8.mlp.up_proj.weight": "model-00002-of-00007.safetensors",
426
+ "model.layers.8.post_attention_layernorm.weight": "model-00002-of-00007.safetensors",
427
+ "model.layers.8.self_attn.k_proj.weight": "model-00002-of-00007.safetensors",
428
+ "model.layers.8.self_attn.o_proj.weight": "model-00002-of-00007.safetensors",
429
+ "model.layers.8.self_attn.q_proj.weight": "model-00002-of-00007.safetensors",
430
+ "model.layers.8.self_attn.v_proj.weight": "model-00002-of-00007.safetensors",
431
+ "model.layers.9.input_layernorm.weight": "model-00002-of-00007.safetensors",
432
+ "model.layers.9.mlp.down_proj.weight": "model-00002-of-00007.safetensors",
433
+ "model.layers.9.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
434
+ "model.layers.9.mlp.up_proj.weight": "model-00002-of-00007.safetensors",
435
+ "model.layers.9.post_attention_layernorm.weight": "model-00002-of-00007.safetensors",
436
+ "model.layers.9.self_attn.k_proj.weight": "model-00002-of-00007.safetensors",
437
+ "model.layers.9.self_attn.o_proj.weight": "model-00002-of-00007.safetensors",
438
+ "model.layers.9.self_attn.q_proj.weight": "model-00002-of-00007.safetensors",
439
+ "model.layers.9.self_attn.v_proj.weight": "model-00002-of-00007.safetensors",
440
+ "model.norm.weight": "model-00007-of-00007.safetensors"
441
+ }
442
+ }
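The index maps every parameter name to one of the seven shards, so individual tensors can be located and loaded without reading the whole checkpoint. A short sketch using the `safetensors` library; the filename and tensor name are taken from the entries above:

```python
# Look up a tensor's shard in the index and load just that tensor.
import json
from safetensors import safe_open

with open('model.safetensors.index.json') as f:
    index = json.load(f)

name = 'model.layers.0.self_attn.q_proj.weight'
shard = index['weight_map'][name]            # 'model-00001-of-00007.safetensors'
with safe_open(shard, framework='pt') as f:
    tensor = f.get_tensor(name)
print(tensor.shape)   # torch.Size([5120, 5120]) for the bf16 shards this index describes
```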
modeling_nanbeige.py ADDED
@@ -0,0 +1,935 @@
1
+ import math
2
+ import queue
3
+ import threading
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ import torch.utils.checkpoint
8
+ from torch import nn
9
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
10
+ from transformers.activations import ACT2FN
11
+ from transformers.generation.streamers import BaseStreamer
12
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, \
13
+ SequenceClassifierOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+
17
+ from .configuration_nanbeige import NanbeigeConfig
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+ try:
22
+ import flash_attn
23
+ from flash_attn import flash_attn_func
24
+
25
+ if tuple(int(v) for v in flash_attn.__version__.split(".")[:2]) >= (2, 1):
26
+ Version_ = True
27
+ else:
28
+ Version_ = False
29
+ except:
30
+ logger.warning(
31
+ "Warning: import flash_attn fail, please install FlashAttention to get higher efficiency "
32
+ "https://github.com/Dao-AILab/flash-attention"
33
+ )
34
+ Version_ = False
35
+ flash_attn_func = None
36
+
37
+
38
+ def _make_causal_mask(
39
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
40
+ ):
41
+ """
42
+ Make causal mask used for bi-directional self-attention.
43
+ """
44
+ bsz, tgt_len = input_ids_shape
45
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
46
+ mask_cond = torch.arange(mask.size(-1), device=device)
47
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
48
+ mask = mask.to(dtype)
49
+
50
+ if past_key_values_length > 0:
51
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
52
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
53
+
54
+
55
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
56
+ """
57
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
58
+ """
59
+ bsz, src_len = mask.size()
60
+ tgt_len = tgt_len if tgt_len is not None else src_len
61
+
62
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
63
+
64
+ inverted_mask = 1.0 - expanded_mask
65
+
66
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
67
+
68
+
69
+ def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
70
+ return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
71
+
72
+
73
+ def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
74
+ low = math.floor(find_correction_dim(
75
+ low_rot, dim, base, max_position_embeddings))
76
+ high = math.ceil(find_correction_dim(
77
+ high_rot, dim, base, max_position_embeddings))
78
+ return max(low, 0), min(high, dim - 1) # Clamp values just in case
79
+
80
+
81
+ def linear_ramp_mask(min, max, dim):
82
+ if min == max:
83
+ max += 0.001 # Prevent singularity
84
+
85
+ linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
86
+ ramp_func = torch.clamp(linear_func, 0, 1)
87
+ return ramp_func
88
+
89
+
90
+ def get_mscale(scale=1):
91
+ if scale <= 1:
92
+ return 1.0
93
+ return 0.1 * math.log(scale) + 1.0
94
+
95
+
96
+ class YaRNScaledRotaryEmbedding(torch.nn.Module):
97
+ def __init__(self, dim, max_position_embeddings=4096, base=10000, scale=1, original_max_position_embeddings=4096,
98
+ extrapolation_factor=1, attn_factor=1, beta_fast=32, beta_slow=1, finetuned=False, device=None):
99
+ super().__init__()
100
+ self.dim = dim
101
+ self.max_position_embeddings = max_position_embeddings
102
+ self.base = base
103
+ self.scale = scale
104
+ self.original_max_position_embeddings = original_max_position_embeddings
105
+ self.extrapolation_factor = extrapolation_factor
106
+ self.attn_factor = attn_factor
107
+ self.beta_fast = beta_fast
108
+ self.beta_slow = beta_slow
109
+
110
+ self.yarn(device)
111
+
112
+ # Build here to make `torch.jit.trace` work.
113
+ self.max_seq_len_cached = max_position_embeddings
114
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
115
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
116
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
117
+ emb = torch.cat((freqs, freqs), dim=-1)
118
+ dtype = torch.get_default_dtype()
119
+
120
+ self.register_buffer("cos_cached", (emb.cos() * self.mscale)[None, None, :, :].to(dtype), persistent=False)
121
+ self.register_buffer("sin_cached", (emb.sin() * self.mscale)[None, None, :, :].to(dtype), persistent=False)
122
+
123
+ def forward(self, x, seq_len=None):
124
+ # x: [bs, num_attention_heads, seq_len, head_size]
125
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
126
+ if seq_len > self.max_seq_len_cached:
127
+ self.max_seq_len_cached = seq_len
128
+
129
+ t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
130
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
131
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
132
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
133
+
134
+ self.register_buffer("cos_cached", (emb.cos() * self.mscale)[None, None, :, :].to(x.dtype),
135
+ persistent=False)
136
+ self.register_buffer("sin_cached", (emb.sin() * self.mscale)[None, None, :, :].to(x.dtype),
137
+ persistent=False)
138
+ return (
139
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
140
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
141
+ )
142
+
143
+ def yarn(self, device):
144
+ pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
145
+ inv_freq_extrapolation = 1.0 / pos_freqs
146
+ inv_freq_interpolation = 1.0 / (self.scale * pos_freqs)
147
+
148
+ low, high = find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base,
149
+ self.original_max_position_embeddings)
150
+ inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).float().to(
151
+ device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
152
+ inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
153
+
154
+ self.register_buffer("inv_freq", inv_freq)
155
+ self.mscale = float(
156
+ get_mscale(self.scale) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation
157
+
158
+
159
+ class RMSNorm(nn.Module):
160
+ def __init__(self, hidden_size, eps=1e-6):
161
+ super().__init__()
162
+ self.weight = nn.Parameter(torch.ones(hidden_size))
163
+ self.variance_epsilon = eps
164
+
165
+ def forward(self, hidden_states):
166
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
167
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
168
+
169
+ # convert into half-precision if necessary
170
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
171
+ hidden_states = hidden_states.to(self.weight.dtype)
172
+
173
+ return self.weight * hidden_states
174
+
175
+
176
+ class RotaryEmbedding(torch.nn.Module):
177
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
178
+ super().__init__()
179
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
180
+ self.register_buffer("inv_freq", inv_freq)
181
+ self.max_seq_len_cached = max_position_embeddings
182
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
183
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
184
+ emb = torch.cat((freqs, freqs), dim=-1)
185
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
186
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
187
+
188
+ def forward(self, x, seq_len=None):
189
+ # x: [bs, num_attention_heads, seq_len, head_size]
190
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
191
+ if seq_len > self.max_seq_len_cached:
192
+ self.max_seq_len_cached = seq_len
193
+ t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
194
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
195
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
196
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
197
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
198
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
199
+ return (
200
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
201
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
202
+ )
203
+
204
+
205
+ def rotate_half(x):
206
+ """Rotates half the hidden dims of the input."""
207
+ x1 = x[..., : x.shape[-1] // 2]
208
+ x2 = x[..., x.shape[-1] // 2:]
209
+ return torch.cat((-x2, x1), dim=-1)
210
+
211
+
212
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
213
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
214
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
215
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
216
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
217
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
218
+ q_embed = (q * cos) + (rotate_half(q) * sin)
219
+ k_embed = (k * cos) + (rotate_half(k) * sin)
220
+ return q_embed, k_embed
221
+
222
+
223
+ class NanbeigeMLP(nn.Module):
+     def __init__(
+         self,
+         hidden_size: int,
+         intermediate_size: int,
+         hidden_act: str,
+     ):
+         super().__init__()
+         self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+         self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
+         self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+         self.act_fn = ACT2FN[hidden_act]
+
+     def forward(self, x):
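+         # Gated feed-forward (SwiGLU when hidden_act is silu): down_proj(act(gate_proj(x)) * up_proj(x)).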
+         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+ class NanbeigeAttention(nn.Module):
+     """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+     def __init__(self, config: NanbeigeConfig):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.head_dim = self.hidden_size // self.num_heads
+         self.max_position_embeddings = config.max_position_embeddings
+
+         if (self.head_dim * self.num_heads) != self.hidden_size:
+             raise ValueError(
+                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+                 f" and `num_heads`: {self.num_heads})."
+             )
+         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+         self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+         self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+         if self.config.yarn_scale > 1:
+             self.rotary_emb = YaRNScaledRotaryEmbedding(self.head_dim, scale=self.config.yarn_scale,
+                                                         original_max_position_embeddings=self.max_position_embeddings)
+         else:
+             self.rotary_emb = RotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
+
+     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+         output_attentions: bool = False,
+         use_cache: bool = False,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+         bsz, q_len, _ = hidden_states.size()
+
+         query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+         key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+         value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+         kv_seq_len = key_states.shape[-2]
+         if past_key_value is not None:
+             kv_seq_len += past_key_value[0].shape[-2]
+         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+         if past_key_value is not None:
+             key_states = torch.cat([past_key_value[0], key_states], dim=2)
+             value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+         past_key_value = (key_states, value_states) if use_cache else None
+
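+         # Use the fused flash-attention kernel when it is available and query/key shapes match
+         # (i.e. no cached prefix); otherwise fall back to the explicit softmax attention below.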
+         if Version_ or (flash_attn_func and query_states.size() == key_states.size()):
+             attn_output = flash_attn_func(query_states.transpose(1, 2), key_states.transpose(1, 2),
+                                           value_states.transpose(1, 2), dropout_p=0.0, softmax_scale=None, causal=True)
+         else:
+             attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+             if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+                 raise ValueError(
+                     f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                     f" {attn_weights.size()}"
+                 )
+
+             if attention_mask is not None:
+                 if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                     raise ValueError(
+                         f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                     )
+                 attn_weights = attn_weights + attention_mask
+                 attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+
+             # upcast attention to fp32
+             attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+             attn_output = torch.matmul(attn_weights, value_states)
+
+             if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+                 raise ValueError(
+                     f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                     f" {attn_output.size()}"
+                 )
+
+             attn_output = attn_output.transpose(1, 2)
+
+         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+         attn_output = self.o_proj(attn_output)
+
+         if not output_attentions:
+             attn_weights = None
+
+         return attn_output, attn_weights, past_key_value
+
+
+ class NanbeigeDecoderLayer(nn.Module):
+     def __init__(self, config: NanbeigeConfig):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.self_attn = NanbeigeAttention(config=config)
+         self.mlp = NanbeigeMLP(
+             hidden_size=self.hidden_size,
+             intermediate_size=config.intermediate_size,
+             hidden_act=config.hidden_act,
+         )
+         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+         output_attentions: Optional[bool] = False,
+         use_cache: Optional[bool] = False,
+     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+         """
+         Args:
+             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+             attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                 returned tensors for more detail.
+             use_cache (`bool`, *optional*):
+                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                 (see `past_key_values`).
+             past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+         """
+
+         residual = hidden_states
+
+         hidden_states = self.input_layernorm(hidden_states)
+
+         # Self Attention
+         hidden_states, self_attn_weights, present_key_value = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_value=past_key_value,
+             output_attentions=output_attentions,
+             use_cache=use_cache,
+         )
+         hidden_states = residual + hidden_states
+
+         # Fully Connected
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+
+         outputs = (hidden_states,)
+
+         if output_attentions:
+             outputs += (self_attn_weights,)
+
+         if use_cache:
+             outputs += (present_key_value,)
+
+         return outputs
+
+
+ class NanbeigePreTrainedModel(PreTrainedModel):
+     config_class = NanbeigeConfig
+     base_model_prefix = "model"
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["NanbeigeDecoderLayer"]
+     _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
+
+     def _init_weights(self, module):
+         std = self.config.initializer_range
+         if isinstance(module, nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.padding_idx is not None:
+                 module.weight.data[module.padding_idx].zero_()
+
+     def _set_gradient_checkpointing(self, module, value=False):
+         if isinstance(module, NanbeigeModel):
+             module.gradient_checkpointing = value
+
+
+ class NanbeigeModel(NanbeigePreTrainedModel):
+     def __init__(self, config: NanbeigeConfig):
+         super().__init__(config)
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+         self.layers = nn.ModuleList([NanbeigeDecoderLayer(config) for _ in range(config.num_hidden_layers)])
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+         self.gradient_checkpointing = False
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.embed_tokens = value
+
+     # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
+     def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
+         # create causal mask
+         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+         combined_attention_mask = None
+         if input_shape[-1] > 1:
+             combined_attention_mask = _make_causal_mask(
+                 input_shape,
+                 inputs_embeds.dtype,
+                 device=inputs_embeds.device,
+                 past_key_values_length=past_key_values_length,
+             )
+
+         if attention_mask is not None:
+             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+             expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+                 inputs_embeds.device
+             )
+             combined_attention_mask = (
+                 expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+             )
+
+         return combined_attention_mask
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPast]:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # retrieve input_ids and inputs_embeds
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+         elif input_ids is not None:
+             batch_size, seq_length = input_ids.shape
+         elif inputs_embeds is not None:
+             batch_size, seq_length, _ = inputs_embeds.shape
+         else:
+             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+         seq_length_with_past = seq_length
+         past_key_values_length = 0
+
+         if past_key_values is not None:
+             past_key_values_length = past_key_values[0][0].shape[2]
+             seq_length_with_past = seq_length_with_past + past_key_values_length
+         else:
+             past_key_values = [None for _ in range(len(self.layers))]
+
+         if position_ids is None:
+             device = input_ids.device if input_ids is not None else inputs_embeds.device
+             position_ids = torch.arange(
+                 past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+             )
+             position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+         else:
+             position_ids = position_ids.view(-1, seq_length).long()
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_tokens(input_ids)
+         # embed positions
+         if attention_mask is None:
+             attention_mask = torch.ones(
+                 (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+             )
+         attention_mask = self._prepare_decoder_attention_mask(
+             attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+         )
+
+         hidden_states = inputs_embeds
+
+         if self.gradient_checkpointing and self.training:
+             if use_cache:
+                 logger.warning_once(
+                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                 )
+                 use_cache = False
+
+         # decoder layers
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+         next_cache = [] if use_cache else None
+
+         for idx, decoder_layer in enumerate(self.layers):
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+
+             past_key_value = past_key_values.pop(0) if past_key_values is not None else None
+
+             if self.gradient_checkpointing and self.training:
+
+                 def create_custom_forward(module):
+                     def custom_forward(*inputs):
+                         # None for past_key_value
+                         return module(*inputs, output_attentions, None)
+
+                     return custom_forward
+
+                 layer_outputs = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(decoder_layer),
+                     hidden_states,
+                     attention_mask,
+                     position_ids,
+                     None,
+                 )
+             else:
+                 layer_outputs = decoder_layer(
+                     hidden_states,
+                     attention_mask=attention_mask,
+                     position_ids=position_ids,
+                     past_key_value=past_key_value,
+                     output_attentions=output_attentions,
+                     use_cache=use_cache,
+                 )
+
+             hidden_states = layer_outputs[0]
+
+             if use_cache:
+                 next_cache.append(layer_outputs[2 if output_attentions else 1])
+
+             if output_attentions:
+                 all_self_attns += (layer_outputs[1],)
+
+         hidden_states = self.norm(hidden_states)
+
+         # add hidden states from the last decoder layer
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+
+         if not return_dict:
+             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=next_cache,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attns,
+         )
+
+
+ class NanbeigeForCausalLM(NanbeigePreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = NanbeigeModel(config)
+
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.model.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.model.embed_tokens = value
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+
+     def set_decoder(self, decoder):
+         self.model = decoder
+
+     def get_decoder(self):
+         return self.model
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         hidden_states = outputs[0]
+         logits = self.lm_head(hidden_states)
+
+         loss = None
+         if labels is not None:
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             # Flatten the tokens
+             loss_fct = CrossEntropyLoss()
+             shift_logits = shift_logits.view(-1, self.config.vocab_size)
+             shift_labels = shift_labels.view(-1)
+             # Enable model parallelism
+             shift_labels = shift_labels.to(shift_logits.device)
+             loss = loss_fct(shift_logits, shift_labels)
+
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return (loss,) + output if loss is not None else output
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+     def prepare_inputs_for_generation(
+         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+     ):
+         if past_key_values:
+             input_ids = input_ids[:, -1:]
+
+         position_ids = kwargs.get("position_ids", None)
+         if attention_mask is not None and position_ids is None:
+             # create position_ids on the fly for batch generation
+             position_ids = attention_mask.long().cumsum(-1) - 1
+             position_ids.masked_fill_(attention_mask == 0, 1)
+             if past_key_values:
+                 position_ids = position_ids[:, -1].unsqueeze(-1)
+
+         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+         if inputs_embeds is not None and past_key_values is None:
+             model_inputs = {"inputs_embeds": inputs_embeds}
+         else:
+             model_inputs = {"input_ids": input_ids}
+
+         model_inputs.update(
+             {
+                 "position_ids": position_ids,
+                 "past_key_values": past_key_values,
+                 "use_cache": kwargs.get("use_cache"),
+                 "attention_mask": attention_mask,
+             }
+         )
+         return model_inputs
+
+     @staticmethod
+     def _reorder_cache(past_key_values, beam_idx):
+         reordered_past = ()
+         for layer_past in past_key_values:
+             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+         return reordered_past
+
+     @torch.no_grad()
+     def chat(self,
+              tokenizer,
+              query: str,
+              messages: List[dict] = None,
+              streamer: Optional[BaseStreamer] = None,
+              max_new_tokens: int = 512,
+              do_sample: bool = True,
+              temperature: float = 0.3,
+              top_p: float = 0.9,
+              **kwargs):
+         if messages is None:
+             messages = []
+         messages = messages + [{'role': 'user', 'content': query}]
+         prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+         inputs = tokenizer(prompt, add_special_tokens=False, return_tensors='pt')
+         inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
+         outputs = self.generate(**inputs,
+                                 streamer=streamer,
+                                 max_new_tokens=max_new_tokens,
+                                 do_sample=do_sample,
+                                 temperature=temperature,
+                                 top_p=top_p,
+                                 **kwargs)
+         outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]):]
+         response = tokenizer.decode(outputs, skip_special_tokens=True)
+         response = response.split(tokenizer.eos_token)[0]
+         # `messages` already contains the current user turn; only the assistant reply is appended.
+         messages.append({'role': 'assistant', 'content': response})
+         return response, messages
+
+     @torch.no_grad()
+     def stream_chat(self,
+                     tokenizer,
+                     query: str,
+                     messages: List[dict] = None,
+                     max_new_tokens: int = 1024,
+                     do_sample: bool = True,
+                     temperature: float = 0.8,
+                     top_p: float = 0.8,
+                     **kwargs):
+
+         response_queue = queue.Queue(maxsize=20)
+         if messages is None:
+             messages = [{'role': 'system', 'content': NANBEIGE_SYSTEM_PROMPT}]
+
+         class ChatStreamer(BaseStreamer):
+             def __init__(self, tokenizer) -> None:
+                 super().__init__()
+                 self.tokenizer = tokenizer
+                 self.queue = response_queue
+                 self.query = query
+                 self.messages = messages
+                 self.response = ""
+                 self.received_inputs = False
+                 self.queue.put((self.response, messages + [{'role': 'user', 'content': self.query},
+                                                            {'role': 'assistant', 'content': self.response}]))
+
+             def put(self, value):
+                 if len(value.shape) > 1 and value.shape[0] > 1:
+                     raise ValueError("ChatStreamer only supports batch size 1")
+                 elif len(value.shape) > 1:
+                     value = value[0]
+
+                 if not self.received_inputs:
+                     # The first received value is input_ids, ignore here
+                     self.received_inputs = True
+                     return
+
+                 token = self.tokenizer.decode([value[-1]], skip_special_tokens=True)
+                 if token.strip() != "</s>":
+                     self.response = self.response + token
+                     messages = self.messages + [{'role': 'user', 'content': self.query},
+                                                 {'role': 'assistant', 'content': self.response}]
+                     self.queue.put((self.response, messages))
+
+             def end(self):
+                 self.queue.put(None)
+
+         def stream_task():
+             return self.chat(
+                 tokenizer=tokenizer,
+                 query=query,
+                 messages=messages,
+                 streamer=ChatStreamer(tokenizer=tokenizer),
+                 max_new_tokens=max_new_tokens,
+                 do_sample=do_sample,
+                 temperature=temperature,
+                 top_p=top_p,
+                 **kwargs
+             )
+
+         def consumer():
+             threading.Thread(target=stream_task).start()
+             while True:
+                 res = response_queue.get()
+                 if res is None:
+                     return
+                 yield res
+
+         return consumer()
+
+
+ class NanbeigeForSequenceClassification(NanbeigePreTrainedModel):
+     _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+         self.model = NanbeigeModel(config)
+         self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.model.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.model.embed_tokens = value
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         transformer_outputs = self.model(
+             input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         hidden_states = transformer_outputs[0]
+         logits = self.score(hidden_states)
+
+         if input_ids is not None:
+             batch_size = input_ids.shape[0]
+         else:
+             batch_size = inputs_embeds.shape[0]
+
+         if self.config.pad_token_id is None and batch_size != 1:
+             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+         if self.config.pad_token_id is None:
+             sequence_lengths = -1
+         else:
+             if input_ids is not None:
+                 sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+             else:
+                 sequence_lengths = -1
+
+         pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+         loss = None
+         if labels is not None:
+             labels = labels.to(logits.device)
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+
+             if self.config.problem_type == "regression":
+                 loss_fct = MSELoss()
+                 if self.num_labels == 1:
+                     loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(pooled_logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = CrossEntropyLoss()
+                 loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = BCEWithLogitsLoss()
+                 loss = loss_fct(pooled_logits, labels)
+         if not return_dict:
+             output = (pooled_logits,) + transformer_outputs[1:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutputWithPast(
+             loss=loss,
+             logits=pooled_logits,
+             past_key_values=transformer_outputs.past_key_values,
+             hidden_states=transformer_outputs.hidden_states,
+             attentions=transformer_outputs.attentions,
+         )
+
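The modeling code above also exposes `chat` and `stream_chat` helpers on `NanbeigeForCausalLM`. A minimal usage sketch follows; it assumes the model and tokenizer have been loaded from `Nanbeige/Nanbeige2-16B-Chat` with `trust_remote_code=True`, and the multi-turn pattern is inferred from the method signatures above rather than documented behavior.

```python
# Single turn: `chat` renders the chat template, generates, and returns (response, history).
response, history = model.chat(tokenizer, "What can you help me with?")
print(response)

# Multi-turn: pass the returned history back in as `messages`.
response, history = model.chat(tokenizer, "Summarize that in one sentence.", messages=history)

# Streaming: `stream_chat` returns a generator of (partial_response, messages) tuples,
# where each partial_response is the cumulative decoded text so far.
for partial, _ in model.stream_chat(tokenizer, "Write a short poem about autumn."):
    print(partial)
```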
output-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2c5200fb872cb4cc86ef94777feba1f9ffbc3919b4e726e9c3dcdae4065d5ae2
+ size 8538475060
output-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6f6ddcd85f66fab2f82bd8b7ff2114879c21a9ad048701f2ffe4c1e0310dceb8
+ size 3757635664
special_tokens_map.json ADDED
@@ -0,0 +1,27 @@
+ {
+   "additional_special_tokens": [
+     {
+       "content": "<|endoftext|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false
+     }
+   ],
+   "bos_token": "<|im_start|>",
+   "eos_token": "<|im_end|>",
+   "pad_token": {
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
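The map sets `<|im_start|>` and `<|im_end|>` as the BOS/EOS tokens and `<unk>` as both pad and unk. A quick sanity-check sketch (the token ids follow the `added_tokens_decoder` entries in tokenizer_config.json below):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('Nanbeige/Nanbeige2-16B-Chat',
                                    use_fast=False, trust_remote_code=True)
print(tok.bos_token, tok.eos_token, tok.pad_token)  # <|im_start|> <|im_end|> <unk>
print(tok.convert_tokens_to_ids(['<|im_start|>', '<|im_end|>', '<|endoftext|>']))  # expected: [58978, 58979, 58980]
```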
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ede13db1d0956ec033608741b0fd83d149a5ec54306af70e2ba829242f75b73b
+ size 851705
tokenizer_config.json ADDED
@@ -0,0 +1,71 @@
+ {
+   "add_bos_token": false,
+   "add_eos_token": false,
+   "added_tokens_decoder": {
+     "0": {
+       "content": "<unk>",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "<s>",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "</s>",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "58978": {
+       "content": "<|im_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "58979": {
+       "content": "<|im_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "58980": {
+       "content": "<|endoftext|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "additional_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>",
+     "<|endoftext|>"
+   ],
+   "bos_token": "<|im_start|>",
+   "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\n你是一个名为\"南北阁\"的人工智能助手,正在与人类用户进行交谈。你的目标是以最有帮助和最逻辑的方式回答问题,同时确保内容的安全性。你的回答中不应包含任何有害、政治化、宗教化、不道德、种族主义、非法的内容。请确保你的回答不带有社会偏见,符合社会主义价值观。如果遇到的问题无意义或事实上不连贯,请不要回答错误的内容,而是解释问题为何无效或不连贯。如果你不知道问题的答案,也请勿提供错误的信息。<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'].strip() + '<|im_end|>' + '\n'}}{% endfor %}{% if (add_generation_prompt is defined and add_generation_prompt) %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "<|im_end|>",
+   "legacy": true,
+   "model_max_length": 4096,
+   "pad_token": "<unk>",
+   "sp_model_kwargs": {},
+   "spaces_between_special_tokens": false,
+   "tokenizer_class": "LlamaTokenizer",
+   "unk_token": "<unk>",
+   "use_default_system_prompt": false
+ }
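The `chat_template` above produces a ChatML-style prompt and injects the default Nanbeige system prompt whenever the first message is not a system message. A small sketch of rendering the template (the exact rendered text depends on the template string above):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('Nanbeige/Nanbeige2-16B-Chat',
                                    use_fast=False, trust_remote_code=True)
prompt = tok.apply_chat_template([{'role': 'user', 'content': 'Hi'}],
                                 add_generation_prompt=True, tokenize=False)
print(prompt)
# Expected structure (system block comes from the template's default prompt):
# <|im_start|>system
# 你是一个名为"南北阁"的人工智能助手,...<|im_end|>
# <|im_start|>user
# Hi<|im_end|>
# <|im_start|>assistant
```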