liusx14 committed on
Commit
ab88734
1 Parent(s): 822d7e5
README.md CHANGED
@@ -1,3 +1,150 @@
- ---
- license: apache-2.0
- ---
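+ <div align="center">
+ <h1>
+ 星辰语义大模型-TeleChat (Xingchen Semantic Large Model, TeleChat)
+ </h1>
+ </div>
+
+ <p align="center">
+ 🤗 <a href="https://huggingface.co/Tele-AI" target="_blank">Hugging Face</a> • 🏔 <a href="" target="_blank">MindSpore</a>️ • 🦉 <a href="https://github.com/Tele-AI/Telechat" target="_blank">github</a>️ • 🐾 <a href="https://gitee.com/Tele-AI/tele-chat" target="_blank">gitee</a>️ • 💬 <a href="https://github.com/Tele-AI/Telechat/blob/master/images/wechat.jpg" target="_blank">WeChat</a>
+ </p>
+
+ <p align="center">
+ <a href="https://arxiv.org/abs/2401.03804" target="_blank"> Tech Report </a>
+ </p>
+
+
+ # Latest News
+ - 2024.5.17: Open-sourced the 12B-v2 chat model and its quantized versions
+ - 2024.3.20: Open-sourced the 12B chat model and its quantized versions
+ - 2024.1.11: Open-sourced a 1T Chinese dataset
+ - 2024.1.10: Open-sourced the 7B chat model and its quantized versions
+
+ # Model Introduction
+ ### TeleChat (星辰语义大模型)
+ - TeleChat is a large language model developed and trained by China Telecom Artificial Intelligence Technology Co., Ltd. (中电信人工智能科技有限公司). The 7B base model was trained on 1.5 trillion tokens of high-quality Chinese and English text, and the 12B base model on 3 trillion tokens.
+ - We open-source the chat models **TeleChat-7B-bot** and **TeleChat-12B-bot** together with their weights in `huggingface` format, as well as int8 and int4 quantized versions of both the 7B and 12B models.
+ - **TeleChat-12B-bot** improves on the model structure, training data, and training method, and scores substantially higher than **TeleChat-7B-bot** on general question answering, knowledge, code, and math benchmarks.
+   - Model structure: we tried many combinations of architectural choices on small-scale models and selected the best one. Unlike **TeleChat-7B-bot**, **TeleChat-12B-bot** decouples the word embedding layer from the output layer, so the embedding and the output lm head have separate parameters, which helps training stability and convergence.
+   - Training data: we collected a large amount of Chinese and English data covering books, encyclopedias, news, government affairs, law, medicine, patents, papers, mathematics, code, and more. An improved data-cleaning strategy substantially raised text cleanliness, neutrality of viewpoint, content validity, and format consistency.
+   - Training method: we combine data-mixture learning with curriculum learning. Small models are fitted on data with various mixture ratios to obtain a prior estimate of each dataset's difficulty. During training, the current model's loss on every dataset and its generation quality on the evaluation sets are assessed automatically at regular intervals, and the weights of datasets that are harder to learn are raised dynamically so that the model fits every dataset well.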
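
The README does not give the update rule used for dynamic dataset weighting. The following is only an illustrative sketch of the idea: datasets whose current loss sits above a reference estimate get a larger sampling weight. The function name `reweight`, the `step_size` parameter, and the exponential update are all assumptions, not the authors' method.

```python
import math


def reweight(weights, losses, reference_losses, step_size=0.5):
    """Raise the sampling weight of datasets that are still hard to fit.

    weights:          current sampling probability per dataset (sums to 1)
    losses:           loss of the current model on each dataset
    reference_losses: prior difficulty estimate (e.g. from a small model)
    """
    raw = {}
    for name, weight in weights.items():
        # Only a positive gap (model worse than the reference) boosts the weight
        gap = max(losses[name] - reference_losses[name], 0.0)
        raw[name] = weight * math.exp(step_size * gap)
    total = sum(raw.values())
    # Renormalize so the result is again a probability distribution
    return {name: value / total for name, value in raw.items()}


new_weights = reweight(
    weights={"code": 0.5, "news": 0.5},
    losses={"code": 2.0, "news": 1.0},
    reference_losses={"code": 1.5, "news": 1.2},
)
print(new_weights)
```

In this toy run the `code` dataset lags its reference, so its share grows while `news` shrinks correspondingly.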
+
+ ### Model Structure
+
+ **TeleChat** uses a standard `Decoder-only` architecture with the following changes:
+
+ - **Position encoding**: we use [Rotary Embedding](https://arxiv.org/pdf/2104.09864.pdf), which builds relative position dependence into self-attention and extrapolates well to longer positions. Rotary Embedding also works well with Flash-Attention v2, which speeds up training by about 20%.
+ - **Activation function**: we use [SwiGLU](https://arxiv.org/pdf/2002.05202.pdf) instead of GELU. To reduce computation, `ffn_hidden_size` is set below the 4x hidden size used in the original SwiGLU.
+ - **Layer normalization**: Pre-Normalization based on [RMSNorm](https://arxiv.org/abs/1910.07467).
+ - **Decoupled embedding and output layers**: in **TeleChat-12B-bot** the word embedding layer and the output lm head have separate parameters, which helps training stability and convergence.
+
+
+ |     | layer_num | hidden_size | ffn_hidden_size | head_num | tie_word_embeddings |
+ |-----| --------- | ----------- | --------------- | -------- | ------------------- |
+ | 7B  | 30        | 4096        | 12288           | 32       | Yes                 |
+ | 12B | 38        | 5120        | 12288           | 32       | No                  |
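
As a sanity check on the table, the rough parameter counts can be estimated from these dimensions. This is a back-of-the-envelope sketch under stated assumptions: four `hidden x hidden` attention projections, three SwiGLU matrices per layer, no biases or norm weights, a vocabulary of 160256 for 7B (the default in `configuration_telechat.py`) and 120000 for 12B (from `config.json`). The helper `estimate_params` is hypothetical.

```python
def estimate_params(n_layer, hidden, ffn_hidden, vocab, tied):
    """Approximate parameter count of a decoder-only Transformer."""
    attention = 4 * hidden * hidden      # Q, K, V and output projections
    mlp = 3 * hidden * ffn_hidden        # SwiGLU: gate, up and down projections
    # Tied embeddings share one matrix; untied models store two
    embeddings = vocab * hidden * (1 if tied else 2)
    return n_layer * (attention + mlp) + embeddings


params_7b = estimate_params(30, 4096, 12288, 160256, tied=True)
params_12b = estimate_params(38, 5120, 12288, 120000, tied=False)
print(f"7B  estimate: {params_7b / 1e9:.2f}B")
print(f"12B estimate: {params_12b / 1e9:.2f}B")
```

Both estimates land close to the advertised model sizes, which suggests the table and the untied-embedding note are mutually consistent.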
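+
+ ---
+
+ The open-source TeleChat models:
+ - Support DeepSpeed fine-tuning. We open-source DeepSpeed-based training code that supports ZeRO memory optimization and integrates FlashAttention2.
+ - Support multi-turn dialogue. We open-source the multi-turn data construction method and integrate a multi-turn mask-loss training scheme that focuses the loss on the answers in each turn and improves answer quality.
+ - Extrapolate to longer contexts. We open-source a model trained at 8K context that uses NTK-aware extrapolation and attention scaling to extrapolate to 96K.
+ - Generate long text well. The models perform well on long-form writing tasks such as work summaries, work plans, PPT outlines, essays (申论), tender documents, emails, proposals, weekly reports, and job descriptions.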
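
The multi-turn mask loss mentioned above can be illustrated as follows. This is a minimal sketch, not the released training code: it assumes each turn is already tokenized, and it uses `-100`, the default `ignore_index` of PyTorch's `CrossEntropyLoss`, to exclude user tokens from the loss. The name `build_labels` and the `(role, token_ids)` layout are hypothetical.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_labels(turns):
    """Concatenate turns and keep the loss only on bot answers.

    turns: list of (role, token_ids) pairs, role is "user" or "bot".
    """
    input_ids, labels = [], []
    for role, token_ids in turns:
        input_ids.extend(token_ids)
        if role == "bot":
            labels.extend(token_ids)                        # learn the answer
        else:
            labels.extend([IGNORE_INDEX] * len(token_ids))  # skip the question
    return input_ids, labels


ids, labels = build_labels([
    ("user", [20, 11, 12]),
    ("bot", [21, 31, 32, 2]),
    ("user", [20, 13]),
    ("bot", [21, 33, 2]),
])
print(ids)
print(labels)
```

Every answer in the conversation contributes to the loss in one pass, rather than only the final turn.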
+
+
+ The released versions and download links are listed below.
+
+ | Model version | huggingface download link | modelscope download link |
+ |----------|-----------------------------------------------------------------------|------------------------------------|
+ | 7B-FP16 | [TeleChat-7B-FP16-hf](https://huggingface.co/Tele-AI/Telechat-7B) |[TeleChat-7B-FP16-ms](https://modelscope.cn/models/TeleAI/telechat-7B) |
+ | 7B-int8 | [TeleChat-7B-int8-hf](https://huggingface.co/Tele-AI/Telechat-7B-int8)|[TeleChat-7B-int8-ms](https://modelscope.cn/models/TeleAI/telechat-7B-int8) |
+ | 7B-int4 | [TeleChat-7B-int4-hf](https://huggingface.co/Tele-AI/Telechat-7B-int4)|[TeleChat-7B-int4-ms](https://modelscope.cn/models/TeleAI/telechat-7B-int4) |
+ | 12B-FP16 | [TeleChat-12B-FP16-hf](https://huggingface.co/Tele-AI/TeleChat-12B)|[TeleChat-12B-FP16-ms](https://modelscope.cn/models/TeleAI/TeleChat-12B) |
+ | 12B-int8 | [TeleChat-12B-int8-hf](https://huggingface.co/Tele-AI/TeleChat-12B-int8)|[TeleChat-12B-int8-ms](https://modelscope.cn/models/TeleAI/TeleChat-12B-int8) |
+ | 12B-int4 | [TeleChat-12B-int4-hf](https://huggingface.co/Tele-AI/TeleChat-12B-int4)|[TeleChat-12B-int4-ms](https://modelscope.cn/models/TeleAI/TeleChat-12B-int4) |
+
+ **Environment image download**
+ To help you get started quickly, we provide a ready-to-run environment image: [image download](https://cloud.189.cn/web/share?code=vQFJRf7JBfmq) (access code: ona6)
+
+ # Open-Source Data
+ ### Data Introduction
+ TeleChat-PTD is a comprehensive, large-scale Chinese dataset extracted from the pre-training corpus of the **TeleChat** model. The data comes mainly from web pages, books, and official media. We filtered it with a combination of rules and models and removed near-duplicates by similarity, in order to retain data of the highest possible quality.
+
+ The TeleChat-PTD release contains about 270 million records of pure Chinese text. The raw size is about 1 TB, or 480 GB compressed, in 189 files. Other redundant information has already been removed from the dataset.
+
+ ### Data Download
+
+ huggingface: [TeleChat-PTD](https://huggingface.co/datasets/Tele-AI/TeleChat-PTD)
+
+ modelscope: [TeleChat-PTD](https://modelscope.cn/datasets/TeleAI/TeleChat-PTD)
+
+ Tianyi Cloud (天翼云盘): [data download](https://cloud.189.cn/t/ia2QbaVzYf6z) (access code: pkg8)
+
+ # Evaluation
+ TeleChat performs well against models of similar size. Our evaluation covers datasets including MMLU, C-Eval, GAOKAO, AGIEval, CMMLU, GSM8K, MATH, HumanEval, and CHID, and measures natural language understanding, knowledge, mathematical computation and reasoning, and code generation.
+
+ ## Evaluation Results
+
+ | Model | MMLU | C-Eval | CMMLU | AGIEval | GAOKAO | GSM8K | MATH | HumanEval | CSL | CHID | EPRSTMT | BBH | HellaSwag |
+ |:--------------------|:--------:|:------:|:------:|:---------:|:---------:|:------:|:------:|:---------:|:---------:|:---------:|:--------:|:------:|:---------:|
+ | | 5-shot | 5-shot | 5-shot | zero-shot | zero-shot | 4-shot | 4-shot | zero-shot | zero-shot | zero-shot |zero-shot | 3-shot | zero-shot |
+ | LLaMA2-7B-chat | 46.2 | 31.9 | 31.5 | 28.5 | 16.1 | 26.3 | 3.9 | 12.2 | 58.8 | 44.1 | 57.5 | 35.6 | 74.1 |
+ | LLaMA2-13B-chat | 54.6 | 36.2 | 38.7 | 32.3 | 18.6 | 29.6 | 5.0 | 18.9 | 61.2 | 48.0 | 59.4 | 40.2 | 78.2 |
+ | ChatGLM2-6B-chat | 45.9 | 52.6 | 49.3 | 39.0 | 46.4 | 28.8 | 6.5 | 11.0 | 61.2 | 57.9 | 71.2 | 32.7 | 57.0 |
+ | ChatGLM3-6B-chat | 51.9 | 53.8 | 54 | 38.9 | 49.3 | 56.7 | 18.7 | 61 | 65.6 | 63.4 | 85 | 44.6 | 62.7 |
+ | Baichuan2-7B-chat | 52.8 | 55.6 | 54.0 | 35.3 | 39.7 | 32.8 | 6 | 13.4 | 60 | 75.2 | 87.5 | 35.8 | 61.6 |
+ | Baichuan2-13B-chat | 57 | 56.7 | 58.4 | 40 | 51.4 | 55.3 | 8.6 | 17.7 | 63.1 | 78.2 | 87.5 | 49.9 | 66.9 |
+ | Qwen-7B-chat | 56.6 | 59.3 | 59.5 | 41.3 | 63.3 | 52.5 | 10.3 | 26.2 | 63.1 | 72.3 | 88.8 | 46.9 | 59.9 |
+ | Qwen-14B-chat | 66.4 | 71.7 | 70.0 | 47.3 | 76.5 | 61.0 | 26.8 | 36.6 | 55.6 | 72.3 | 91.2 | 58.0 | 65.2 |
+ | TeleChat-7B-chat | **60.5** | **64.6** | **64.3** | **46.8** | **59** | **36.7** | **10.3** | **20.1** | **66.8** | **88.0** | **87.5** | **19.5** | **36.7** |
+ | TeleChat-12B-chat | **73.3** | **66.6** | **74.2** | **51.7** | **53.1** | **57.2** | **16.0** | **22.0** | **60.6** | **83.2** | **86.3** | **52.2** | **71.5** |
+
+ Note: CMMLU, AGIEval, GAOKAO, CSL, CHID, and EPRSTMT are evaluated with the methods provided by the [OpenCompass](https://github.com/open-compass/OpenCompass/) platform. For the comparison models we consulted both the officially reported results and the OpenCompass results. MMLU and C-Eval are evaluated with our own scripts; see the `evaluation/` folder for details.
+
+ # Model Inference
+
+ ```python
+ import os
+ import torch
+ from modelscope import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
+
+ # Restrict inference to the first GPU
+ os.environ["CUDA_VISIBLE_DEVICES"] = '0'
+ tokenizer = AutoTokenizer.from_pretrained('TeleAI/TeleChat-12B')
+ model = AutoModelForCausalLM.from_pretrained('TeleAI/TeleChat-12B', trust_remote_code=True, device_map="auto", torch_dtype=torch.float16)
+ generate_config = GenerationConfig.from_pretrained('TeleAI/TeleChat-12B')
+ # Question: "What is the difference between light soy sauce and dark soy sauce?"
+ question="生抽与老抽的区别?"
+ answer, history = model.chat(tokenizer = tokenizer, question=question, history=[], generation_config=generate_config, stream=False)
+ print(answer)
+ # Sample output (the model answers in Chinese):
+ # 生抽和老抽是两种不同的酱油,它们的区别如下:
+ #
+ # 1. 原料不同:生抽是用大豆、小麦等谷物为原料制成的;而老抽则是用豆酱、面酱等发酵后的调味品为原料制成的。
+ #
+ # 2. 制作工艺不同:生抽是通过将大豆浸泡在水中,然后经过蒸煮、发酵等过程制成的;而老抽则是在生抽的基础上加入一定比例的盐、糖、味精等调料,再进行发酵制成的。
+ #
+ # 3. 口感和风味不同:生抽具有咸鲜的味道,口感比较清爽;而老抽则具有特殊的香味和味道,口感相对较重。
+ #
+ # 总的来说,生抽和老抽都是酱油的不同种类,它们在原料、制作工艺和口感等方面都有所不同。
+ ```
+
+
+
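+ # Declaration, License, and Citation
+ ### Declaration
+ We hereby declare that the TeleChat model and its derivatives must not be used for any activity that endangers national or social security or that violates the law. We also ask users not to deploy TeleChat in internet services that have not undergone security review and filing. We hope all users will abide by these principles so that technology develops in a lawful and compliant environment.
+
+ We have done our best to ensure the compliance of the data used in training. Nevertheless, unforeseen issues may remain because of the complexity of the model and the data. We therefore accept no liability for any problem arising from use of the open-source TeleChat model, including but not limited to data security issues, public opinion risks, or any risk or problem caused by the model being misled, abused, disseminated, or improperly exploited.
+
+ ### License
+ Community use of the TeleChat model must comply with the [TeleChat Model Community License Agreement](./TeleChat模型社区许可协议.pdf). TeleChat supports commercial use. If you plan to use the TeleChat model or its derivatives for commercial purposes, submit the application materials required by the agreement to the contact email [email protected]. Once the application is approved, you will be granted a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable commercial copyright license.
+
+ ### Citation
+ To cite our work, please use the following reference:
+ ```
+ @misc{wang2024telechat,
+       title={TeleChat Technical Report},
+       author={Zihan Wang and Xinzhang Liu and Shixuan Liu and Yitong Yao and Yuyao Huang and Zhongjiang He and Xuelong Li and Yongxiang Li and Zhonghao Che and Zhaoxi Zhang and Yan Wang and Xin Wang and Luwen Pu and Huihan Xu and Ruiyu Fang and Yu Zhao and Jie Zhang and Xiaomeng Huang and Zhilong Lu and Jiaxin Peng and Wenjun Zheng and Shiquan Wang and Bingkai Yang and Xuewei he and Zhuoru Jiang and Qiyi Xie and Yanhan Zhang and Zhongqiu Li and Lingling Shi and Weiwei Fu and Yin Zhang and Zilu Huang and Sishi Xiong and Yuxiang Zhang and Chao Wang and Shuangyong Song},
+       year={2024},
+       eprint={2401.03804},
+       archivePrefix={arXiv},
+       primaryClass={cs.CL}
+ }
+ ```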
config.json ADDED
@@ -0,0 +1,43 @@
+ {
+   "apply_residual_connection_post_layernorm": false,
+   "architectures": [
+     "TelechatForCausalLM"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_telechat.TelechatConfig",
+     "AutoModelForCausalLM": "modeling_telechat.TelechatForCausalLM"
+   },
+   "attention_dropout": 0.0,
+   "attention_softmax_in_fp32": true,
+   "bias_dropout_fusion": true,
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "hidden_dropout": 0.0,
+   "hidden_size": 5120,
+   "initializer_range": 0.02,
+   "layer_norm_epsilon": 1e-05,
+   "masked_softmax_fusion": true,
+   "model_type": "telechat",
+   "n_head": 32,
+   "n_inner": null,
+   "n_layer": 38,
+   "offset_alibi": 100,
+   "pad_token_id": 3,
+   "pretraining_tp": 2,
+   "seq_length": 8192,
+   "skip_bias_add": true,
+   "skip_bias_add_qkv": false,
+   "slow_but_exact": false,
+   "transformers_version": "4.24.0",
+   "unk_token_id": 0,
+   "use_cache": true,
+   "vocab_size": 120000,
+   "ffn_hidden_size": 12288,
+   "flash_attn":true,
+   "tie_word_embeddings":false,
+   "training_seqlen":8192,
+   "logn":false,
+   "semi_causal":false,
+   "embed_layernorm":false
+ }
+
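
A few derived quantities follow directly from this config. The sketch below assumes standard multi-head attention in which every layer caches one key and one value vector of size `hidden_size` per token, stored in fp16 (2 bytes per element). The values are copied by hand from the JSON above, and the variable names are illustrative.

```python
# Values copied from the config.json shown in the diff
config = {"hidden_size": 5120, "n_head": 32, "n_layer": 38, "seq_length": 8192}

# Per-head dimension used by attention and rotary embeddings
head_dim = config["hidden_size"] // config["n_head"]

# Assumed KV-cache layout: 2 tensors (K and V) per layer, fp16 = 2 bytes
bytes_per_token = 2 * config["n_layer"] * config["hidden_size"] * 2
bytes_full_context = bytes_per_token * config["seq_length"]

print("head_dim:", head_dim)
print("KV cache per token (bytes):", bytes_per_token)
print("KV cache at full context (GiB):", bytes_full_context / 2**30)
```

Under these assumptions a single 8192-token sequence needs just under 6 GiB of cache on top of the model weights.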
configuration_telechat.py ADDED
@@ -0,0 +1,93 @@
+ # coding=utf-8
+ # Copyright 2022 the Big Science Workshop and HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """ Telechat configuration"""
+
+ from packaging import version
+ from collections import OrderedDict
+ from transformers.utils import is_torch_available, logging
+ from transformers.configuration_utils import PretrainedConfig
+ from typing import TYPE_CHECKING, Any, List, Mapping, Optional
+
+ logger = logging.get_logger(__name__)
+
+ class TelechatConfig(PretrainedConfig):
+     """
+     Args:
+         vocab_size (`int`, *optional*, defaults to 160256): Vocabulary size of the Telechat model.
+         hidden_size (`int`, *optional*, defaults to 4096): Dimensionality of the embeddings and hidden states.
+         ffn_hidden_size (`int`, *optional*, defaults to 12288): Dimensionality of the feed-forward hidden states.
+         n_layer (`int`, *optional*, defaults to 30): Number of hidden layers in the Transformer
+         n_head (`int`, *optional*, defaults to 32): Number of attention heads for each attention layer.
+         layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): The epsilon to use in the layer normalization layers.
+         initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+         apply_residual_connection_post_layernorm (`bool`, *optional*, defaults to `False`): If enabled, use the layer norm of the hidden states as the residual in the transformer blocks
+         hidden_dropout (`float`, *optional*, defaults to 0.0): Dropout rate of the dropout function on the bias dropout.
+         attention_dropout (`float`, *optional*, defaults to 0.0): Dropout rate applied to the attention probs
+         use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions.
+         training_seqlen (`int`, *optional*, defaults to 8192): Sequence length during last finetuning.
+         logn (`bool`, *optional*, defaults to `True`): Whether or not to use logN during extrapolation.
+         embed_layernorm (`bool`, *optional*, defaults to `False`): Whether or not to use embedding layernorm.
+
+     """
+
+     model_type = "telechat"
+     keys_to_ignore_at_inference = ["past_key_values"]
+     # Map the standard transformers attribute names onto this config's names
+     attribute_map = {
+         "num_hidden_layers": "n_layer",
+         "num_attention_heads": "n_head",
+     }
+
+     def __init__(
+         self,
+         vocab_size=160256,
+         hidden_size=4096,
+         n_layer=30,
+         n_head=32,
+         layer_norm_epsilon=1e-5,
+         initializer_range=0.02,
+         use_cache=True,
+         bos_token_id=1,
+         eos_token_id=2,
+         apply_residual_connection_post_layernorm=False,
+         hidden_dropout=0.0,
+         attention_dropout=0.0,
+         ffn_hidden_size=12288,
+         training_seqlen = 8192,
+         logn = True,
+         embed_layernorm = False,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         # Accept the legacy `n_embed` keyword as an alias for `hidden_size`
+         n_embed = kwargs.pop("n_embed", None)
+         self.hidden_size = hidden_size if n_embed is None else n_embed
+         self.n_layer = n_layer
+         self.n_head = n_head
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.initializer_range = initializer_range
+         self.use_cache = use_cache
+         self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
+         self.hidden_dropout = hidden_dropout
+         self.attention_dropout = attention_dropout
+         self.bos_token_id = bos_token_id
+         self.eos_token_id = eos_token_id
+         self.logn = logn
+         self.ffn_hidden_size = ffn_hidden_size
+         self.training_seqlen = training_seqlen
+         self.embed_layernorm = embed_layernorm
+
+
+         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
generation_config.json ADDED
@@ -0,0 +1,14 @@
+ {
+   "max_length": 8192,
+   "do_sample": false,
+   "use_cache": true,
+   "temperature": 0.3,
+   "top_k": 5,
+   "top_p": 0.85,
+   "repetition_penalty": 1.01,
+   "pad_token_id": 3,
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "user_token_id": 20,
+   "bot_token_id": 21
+ }
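
With `do_sample` set to `false`, decoding is greedy, so `temperature`, `top_k`, and `top_p` only matter once sampling is switched on. The following is a simplified, pure-Python sketch of how these three settings narrow the candidate set for one step. It is not the transformers implementation, and `filter_probs` is a hypothetical helper.

```python
import math


def filter_probs(logits, temperature=0.3, top_k=5, top_p=0.85):
    """Return {token_index: probability} after temperature, top-k and top-p."""
    # Temperature below 1 sharpens the distribution
    scaled = [value / temperature for value in logits]
    peak = max(scaled)
    exps = [math.exp(value - peak) for value in scaled]  # stable softmax
    total = sum(exps)
    probs = [value / total for value in exps]

    # top-k: keep only the k most likely tokens
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]

    # top-p: keep the smallest prefix whose cumulative mass reaches top_p
    kept, cumulative = [], 0.0
    for index in ranked:
        kept.append(index)
        cumulative += probs[index]
        if cumulative >= top_p:
            break

    norm = sum(probs[i] for i in kept)
    return {i: probs[i] / norm for i in kept}


logits = [2.0, 1.0, 0.5, 0.0, -1.0, -2.0]
print(filter_probs(logits))                   # settings from generation_config.json
print(filter_probs(logits, temperature=1.0))  # flatter distribution for comparison
```

At temperature 0.3 the top token alone already carries more than 85% of the mass in this toy example, so sampling with these settings would behave almost like greedy decoding.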
generation_utils.py ADDED
@@ -0,0 +1,162 @@
+ from typing import Optional
+ from collections import deque
+ from queue import Queue
+ import copy
+
+
+ class History:
+
+     def __init__(self, tokenizer, history):
+         '''
+         init from a list of dict
+         '''
+         # use deque to meet some special situation
+         self.input_history = deque()
+         self.tokenizer = tokenizer
+         if history:
+             self._transfer_from_list(history)
+
+     def _transfer_from_list(self, history):
+         for message in history:
+             content = message.get("content")
+             # the token result may not be equal to the result model gen
+             message.update(self.tokenizer(content))
+             self.input_history.append(message)
+
+     def append(self, message):
+         content = message.get("content")
+         # Tokenize only if the message does not already carry token ids
+         if "input_ids" not in message or "attention_mask" not in message:
+             message.update(self.tokenizer(content))
+         self.input_history.append(message)
+
+     def append_left(self, message):
+         content = message.get("content")
+         if "input_ids" not in message or "attention_mask" not in message:
+             message.update(self.tokenizer(content))
+         self.input_history.appendleft(message)
+
+     def pop(self):
+         x = self.input_history.pop()
+         return x
+
+     def pop_left(self):
+         # Pop from the deque itself; calling self.pop_left() here would recurse forever
+         x = self.input_history.popleft()
+         return x
+
+     def update(self, message):
+         # Replace the most recent message
+         self.input_history.pop()
+         self.append(message)
+
+     def __len__(self):
+         return self.input_history.__len__()
+
+     def __str__(self):
+         return self.input_history.__str__()
+
+     def __copy__(self):
+         new_instance = type(self)(self.tokenizer, [])
+         new_instance.input_history = copy.copy(self.input_history)
+         return new_instance
+
+     def __deepcopy__(self, memodict={}):
+         new_instance = type(self)(self.tokenizer, [])
+         new_instance.input_history = copy.deepcopy(self.input_history)
+         return new_instance
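
The comment in `History` says only that a deque is used "to meet some special situation". One plausible situation is dropping the oldest turns when the conversation outgrows the context window, which a deque does cheaply from the left. The sketch below illustrates that idea on plain dicts; `trim_history` and the token budget are assumptions, not part of this commit.

```python
from collections import deque


def trim_history(history, max_tokens):
    """Drop the oldest messages until the total token count fits the budget."""
    trimmed = deque(history)
    total = sum(len(message["input_ids"]) for message in trimmed)
    while trimmed and total > max_tokens:
        oldest = trimmed.popleft()  # O(1) removal from the left end
        total -= len(oldest["input_ids"])
    return trimmed


messages = [
    {"role": "user", "content": "a", "input_ids": [1, 2, 3]},
    {"role": "bot", "content": "b", "input_ids": [4, 5, 6, 7]},
    {"role": "user", "content": "c", "input_ids": [8, 9, 10, 11, 12]},
]
kept = trim_history(messages, max_tokens=9)
print([message["role"] for message in kept])
```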
+
+
+ class TelechatIterTextStreamer:
+     """
+     With reference to the TextIterStreamers in transformers, we have rewritten this class
+     """
+
+     def __init__(
+             self, tokenizer, history: Optional[History] = None, skip_prompt: bool = False, timeout: Optional[float] = None,
+             **decode_kwargs
+     ):
+
+         self.tokenizer = tokenizer
+         # Fall back to an empty history so the append below cannot fail on None
+         self.history = history if history is not None else History(tokenizer, [])
+         self.skip_prompt = skip_prompt
+         self.timeout = timeout
+         self.decode_kwargs = decode_kwargs
+
+         self.text_queue = Queue()
+         self.cache_time = 0
+         self.text_until = ""
+         self.token_until = []
+         self.stop_signal = None
+         self.next_tokens_are_prompt = True
+
+         # Placeholder bot message that is updated as text streams in
+         self.history.append({"role": "bot", "content": self.text_until})
+
+     def put(self, value):
+         """
+         put printable text into queue
+         """
+         if len(value.shape) > 1 and value.shape[0] > 1:
+             raise ValueError("TextStreamer only supports batch size 1")
+         elif len(value.shape) > 1:
+             value = value[0]
+
+         if self.skip_prompt and self.next_tokens_are_prompt:
+             self.next_tokens_are_prompt = False
+             return
+
+         if value[-1] == self.tokenizer.eos_token_id:
+             return
+
+         # there may be some smart way to decode.
+         self.token_until.extend(value.tolist())
+         text = self.tokenizer.decode(self.token_until, **self.decode_kwargs)
+
+
+         # Emit only when the text decodes cleanly, or after waiting 6 tokens
+         if self._is_printable(text) or self.cache_time >= 6:
+             output_text = text[len(self.text_until):]
+             self.text_until = text
+
+         else:
+             self.cache_time+=1
+             return
+
+         self.on_finalized_text(output_text)
+
+     def end(self):
+         """Flushes any remaining cache and prints a newline to stdout."""
+         # Flush the cache, if it exists
+         text = self.tokenizer.decode(self.token_until, **self.decode_kwargs)
+         output_text = text[len(self.text_until):]
+         self.text_until = text
+         self.on_finalized_text(output_text, stream_end=True)
+         self.clear_cache()
+
+     def clear_cache(self):
+         self.cache_time = 0
+         self.token_until = []
+         self.text_until = ""
+         self.history = None
+         self.next_tokens_are_prompt = True
+
+     def on_finalized_text(self, text: str, stream_end: bool = False):
+         """Put the text tuple in the queue."""
+         self.history.update({"role": "bot", "content": self.text_until, "input_ids": self.token_until,
+                              "attention_mask": [1] * len(self.token_until)})
+         self.text_queue.put((text, self.history), timeout=self.timeout)
+         if stream_end:
+             self.text_queue.put((self.stop_signal, self.history), timeout=self.timeout)
+
+     @staticmethod
+     def _is_printable(cp):
+         """Checks whether tokens can be decoded or not"""
+         if "�" in cp:
+             return False
+         return True
+
+     def __iter__(self):
+         return self
+
+     def __next__(self):
+         value_now, history_until = self.text_queue.get(timeout=self.timeout)
+         if value_now == self.stop_signal:
+             raise StopIteration()
+         else:
+             return value_now, history_until
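
The streamer holds text back while the decoded string still contains the replacement character, which happens when a multi-byte character is split across tokens. The sketch below reproduces that hold-back logic on raw UTF-8 bytes instead of a tokenizer. It is a simplified illustration: `stream_decode` is hypothetical, and the streamer's "give up after 6 tokens" escape hatch is omitted.

```python
def stream_decode(byte_chunks):
    """Yield-ready text pieces from chunks that may split UTF-8 characters."""
    buffer = b""
    emitted = ""
    pieces = []
    for chunk in byte_chunks:
        buffer += chunk
        text = buffer.decode("utf-8", errors="replace")
        if "\ufffd" in text:
            # Incomplete character at the end: wait for more bytes
            continue
        # Emit only the part that has not been sent yet
        pieces.append(text[len(emitted):])
        emitted = text
    return pieces


data = "生抽".encode("utf-8")            # two characters, three bytes each
chunks = [data[:2], data[2:3], data[3:]]  # first character split across chunks
print(stream_decode(chunks))
```

Re-decoding the whole buffer each time is quadratic in the output length, which matches the source's own remark that "there may be some smart way to decode".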
modeling_telechat.py ADDED
@@ -0,0 +1,910 @@
+ # coding=utf-8
+ # Copyright 2022 HuggingFace Inc. team and BigScience workshop.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+
+ # Copyright (c) 2021 EleutherAI
+ # This file is based on code by the authors denoted below and has been modified from its original version.
+ #
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ """PyTorch TELECHAT model."""
+
+ import warnings
+ from typing import Optional, Tuple, Union, List, Dict
+ from threading import Thread
+
+ import torch
+ import math
+ import copy
+ from torch import nn
+ import torch.utils.checkpoint
+ from torch.nn import functional as F
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
+ from transformers.modeling_outputs import (
+     BaseModelOutputWithPastAndCrossAttentions,
+     CausalLMOutputWithCrossAttentions
+ )
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.utils import logging
+ from transformers import GenerationConfig
+
+ from .configuration_telechat import TelechatConfig
+ from .generation_utils import History, TelechatIterTextStreamer
+
+ logger = logging.get_logger(__name__)
+
+ _CHECKPOINT_FOR_DOC = "telechat"
+ _CONFIG_FOR_DOC = "TelechatConfig"
+
+ TELECHAT_PRETRAINED_MODEL_ARCHIVE_LIST = []
+
+ try:
+     from einops import rearrange
+ except ImportError:
+     rearrange = None
+
+ use_flash_attn = True
+ try:
+     from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+ except ImportError:
+     try:
+         from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func
+     except ImportError:
+         flash_attn_unpadded_func = None
+
+
+ class RotaryEmbedding(torch.nn.Module):
+     # Extracted from: https://github.com/EleutherAI/gpt-neox
+     def __init__(self, dim, config, base=10000):
+         super().__init__()
+         self.config = config
+         self.dim = dim
+         self.base = base
+         self.max_seq_len_cached = None
+         self.cos_cached = None
+         self.sin_cached = None
+
+     def get_mscale(self, scale=1):
+         if scale <= 1:
+             return 1.0
+         return 0.1 * math.log(scale) + 1.0
+
+     def get_ntk_alpha(self, true_seq_len):
+         context_value = math.log(true_seq_len / 4096, 2) + 1
+         ntk_alpha = 2 ** math.ceil(context_value) - 1
+         ntk_alpha = max(ntk_alpha, 1)
+         return ntk_alpha
+
+     def forward(self, x, dtype, seq_dim=0):
+         seq_len = x.shape[seq_dim]
+         self.mscale = 1.0
+         if not self.training:
+             seq_len = max(seq_len, self.config.training_seqlen)
+             self.mscale = float(self.get_mscale(seq_len / self.config.training_seqlen))
+         ntk_alpha = self.get_ntk_alpha(seq_len)
+         base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
+         self.inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=x.device).float() / self.dim))
+         self.max_seq_len_cached = seq_len
+         t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
+         freqs = torch.einsum('i,j->ij', t, self.inv_freq)
+         # Different from paper, but it uses a different permutation in order to obtain the same calculation
+         emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+         # if self.precision == torch.bfloat16:
+         emb = emb.float() if dtype == torch.bfloat16 else emb
+         # [sx, 1 (b * np), hn]
+         self.cos_cached = self.mscale * emb.cos()[:, None, :].to(dtype)
+         self.sin_cached = self.mscale * emb.sin()[:, None, :].to(dtype)
+         return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
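
To see what the NTK-aware scaling in `RotaryEmbedding` does numerically, the two helper formulas can be evaluated on their own. The sketch below copies the arithmetic of `get_ntk_alpha` and `get_mscale` into plain functions. The `scaled_base` helper, the choice `dim=160` (5120 / 32 from `config.json`), and the sample sequence lengths are illustrative assumptions.

```python
import math


def get_ntk_alpha(true_seq_len):
    # Same arithmetic as RotaryEmbedding.get_ntk_alpha in the diff
    context_value = math.log(true_seq_len / 4096, 2) + 1
    return max(2 ** math.ceil(context_value) - 1, 1)


def get_mscale(scale=1.0):
    # Attention scaling factor applied to cos/sin once the context is stretched
    return 1.0 if scale <= 1 else 0.1 * math.log(scale) + 1.0


def scaled_base(seq_len, dim, base=10000):
    # A larger base lowers the rotary frequencies so positions spread further
    return base * get_ntk_alpha(seq_len) ** (dim / (dim - 2))


for seq_len in (8192, 16384, 98304):
    print(seq_len, get_ntk_alpha(seq_len),
          round(get_mscale(seq_len / 8192), 4),
          round(scaled_base(seq_len, dim=160)))
```

The alpha value grows in steps of roughly powers of two, so the rotary base only changes when the sequence length crosses a power-of-two multiple of 4096.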
+
+
+ # rotary pos emb helpers:
+ def rotate_half(x):
+     x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
+     return torch.cat((-x2, x1), dim=x1.ndim - 1)  # dim=-1 triggers a bug in earlier torch versions
+
+
+ def apply_rotary_pos_emb_torch(q, k, cos, sin, offset: int = 0):  # jitting fails with bf16
+     cos, sin = cos[offset:q.shape[0] + offset, ...], sin[offset:q.shape[0] + offset, ...]
+     return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
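
The key property of these helpers is that the dot product of a rotated query and key depends only on the distance between their positions. The pure-Python sketch below mirrors `rotate_half` and the `q * cos + rotate_half(q) * sin` formula on a single vector so the property can be checked without torch. The list-based `apply_rotary` and the sample vectors are illustrative.

```python
import math


def rotate_half(x):
    half = len(x) // 2
    # Mirrors torch.cat((-x2, x1)): negate the second half and move it first
    return [-value for value in x[half:]] + x[:half]


def apply_rotary(x, position, base=10000.0):
    dim = len(x)
    inv_freq = [1.0 / base ** (i / dim) for i in range(0, dim, 2)]
    # Duplicate the angles, matching torch.cat((freqs, freqs), dim=-1)
    angles = [position * freq for freq in inv_freq] * 2
    rotated = rotate_half(x)
    return [x[i] * math.cos(angles[i]) + rotated[i] * math.sin(angles[i])
            for i in range(dim)]


def dot(a, b):
    return sum(u * v for u, v in zip(a, b))


q = [1.0, 2.0, 3.0, 4.0]
k = [0.5, -1.0, 2.0, 0.0]
# Both pairs of positions are 2 apart, so the scores should agree
print(dot(apply_rotary(q, 5), apply_rotary(k, 3)))
print(dot(apply_rotary(q, 12), apply_rotary(k, 10)))
```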
+
+
+ class MixedFusedRMSNorm(nn.Module):
+     # Extracted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
+     def __init__(self, hidden_size, eps=1e-6):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(hidden_size))
+         self.variance_epsilon = eps
+
+     def forward(self, hidden_states):
+         input_dtype = hidden_states.dtype
+         # Compute the statistics in float32 for numerical stability
+         hidden_states = hidden_states.to(torch.float32)
+         variance = hidden_states.pow(2).mean(-1, keepdim=True)
+         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+         return self.weight * hidden_states.to(input_dtype)
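
For intuition, the same normalization can be written for one vector in plain Python: divide by the root mean square of the elements, then apply the learned per-element weight. This is a sketch of the formula only; `rms_norm` is a hypothetical helper and the inputs are made up.

```python
import math


def rms_norm(values, weight, eps=1e-6):
    """Scale `values` to unit root-mean-square, then apply `weight`."""
    mean_square = sum(v * v for v in values) / len(values)
    scale = 1.0 / math.sqrt(mean_square + eps)
    return [w * v * scale for v, w in zip(values, weight)]


normalized = rms_norm([3.0, 4.0], weight=[1.0, 1.0])
print(normalized)
# The mean of the squared outputs is approximately 1
print(sum(v * v for v in normalized) / len(normalized))
```

Unlike LayerNorm there is no mean subtraction and no bias term, which is why the module stores only a single `weight` parameter.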
+
+
+ class FlashSelfAttention(torch.nn.Module):
+     # Extracted from https://github.com/microsoft/Megatron-DeepSpeed/blob/main/megatron/model/transformer.py
+     """Implement the scaled dot product attention with softmax.
+     Arguments
+     ---------
+         softmax_scale: The temperature to use for the softmax attention.
+                       (default: 1/sqrt(d_keys) where d_keys is computed at
+                       runtime)
+         attention_dropout: The dropout rate to apply to the attention
+                            (default: 0.0)
+     """
+
+     def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
+                  device=None, dtype=None):
+         super().__init__()
+         assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, '
+                                                       'e.g., with pip install flash-attn')
+         assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
+         self.causal = causal
+         self.softmax_scale = softmax_scale
+         self.dropout_p = attention_dropout
+
+     def forward(self, q, k, v):
+         """Implements the multihead softmax attention.
+         Arguments
+         ---------
+             q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
+         """
+         assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
+         assert all((i.is_cuda for i in (q, k, v)))
+
+         batch_size, seqlen_q = q.shape[0], q.shape[1]
+         seqlen_k = k.shape[1]
+
+         q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
+         cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
+                                     device=q.device)
+         if self.training:
+             # during training q,k,v always have same seqlen
+             assert seqlen_k == seqlen_q
+
+             is_causal = self.causal
+             cu_seqlens_k = cu_seqlens_q
+             dropout_p = self.dropout_p
+         else:
+             # turn off FA causal mask after first inference autoregressive iteration
+             # only on first autoregressive step q,k,v have same seqlen
+             is_causal = seqlen_q == seqlen_k
+             cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
+                                         device=q.device)
+             dropout_p = 0
+
+         output = flash_attn_unpadded_func(
+             q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
+             dropout_p=dropout_p,
+             softmax_scale=self.softmax_scale, causal=is_causal
+         )
+
+         output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
+         return output
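
The `cu_seqlens` arguments mark where each sequence starts and ends after the batch is flattened to `(batch * seq)`. Because every sequence in the batch has the same length here, the offsets are just multiples of that length. A torch-free sketch of the two bookkeeping steps follows; the function names are illustrative.

```python
def cumulative_seqlens(batch_size, seqlen):
    """Offsets of each sequence in the flattened (batch * seq) layout."""
    # Mirrors torch.arange(0, (batch_size + 1) * seqlen, step=seqlen)
    return list(range(0, (batch_size + 1) * seqlen, seqlen))


def use_causal_mask(seqlen_q, seqlen_k):
    """Inference rule from the diff: causal only while q and k lengths match."""
    return seqlen_q == seqlen_k


print(cumulative_seqlens(batch_size=3, seqlen=4))
print(use_causal_mask(seqlen_q=7, seqlen_k=7))  # prompt pass
print(use_causal_mask(seqlen_q=1, seqlen_k=8))  # later step with cached keys
```

After the first step the single new query may attend to every cached key, so the causal mask is no longer needed.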
+
+
+ def _make_causal_mask(
+         input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
+ ) -> torch.BoolTensor:
+     """
+     Make causal mask used for self-attention.
+     """
+     batch_size, target_length = input_ids_shape
+     mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
+     # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
+     seq_ids = torch.arange(target_length, device=device)
+     mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
+
+     if past_key_values_length > 0:
+         mask[:, :past_key_values_length] = False
+
+     expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
+     return expanded_mask
231
+
232
+
233
+ def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
234
+ """
235
+ Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
236
+ """
237
+ batch_size, src_length = mask.shape
238
+ tgt_length = tgt_length if tgt_length is not None else src_length
239
+
240
+ expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
241
+ return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
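In both mask helpers, `True` marks a position that must not be attended to. The following plain-Python sketch reproduces that convention on nested lists for a single sequence, assuming one past token and a left-padded attention mask; the function names are hypothetical and only illustrate the logic, they are not the model's API:

```python
def make_causal_mask(target_length, past_length):
    """Build a target_length x (past_length + target_length) grid.

    True means "masked out". Past positions are always visible;
    current position i may not look at a later position j (i < j).
    """
    mask = []
    for i in range(target_length):
        past = [False] * past_length
        current = [i < j for j in range(target_length)]
        mask.append(past + current)
    return mask


def combine_with_padding(causal, attention_mask_row):
    """OR the causal grid with an inverted padding row (1 = keep, 0 = pad)."""
    padding = [not bool(keep) for keep in attention_mask_row]
    return [[c or p for c, p in zip(row, padding)] for row in causal]


causal = make_causal_mask(target_length=3, past_length=1)
combined = combine_with_padding(causal, [0, 1, 1, 1])
for row in combined:
    print(row)
```

The combined grid is what `torch.masked_fill` later consumes: every `True` entry is replaced with the most negative finite value before the softmax.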
242
+
243
+
244
+ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
245
+ """
246
+ Dropout add function
247
+
248
+ Args:
249
+ x (`torch.tensor`, *required*):
250
+ input tensor
251
+ residual (`torch.tensor`, *required*):
252
+ residual tensor
253
+ prob (`float`, *required*):
254
+ dropout probability
255
+ training (`bool`, *required*):
256
+ training mode
257
+ """
258
+ out = F.dropout(x, p=prob, training=training)
259
+ out = residual + out
260
+ return out
261
+
262
+
263
+ def telechat_gelu_forward(x: torch.Tensor) -> torch.Tensor:
264
+ """
265
+ Custom bias GELU function. Adapted from Megatron-DeepSpeed code. Here we use a simple implementation (inference) to
266
+ make the model jittable.
267
+
268
+ Args:
269
+ x (`torch.tensor`, *required*):
270
+ input hidden states
271
+ """
272
+ return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
273
+
274
+
275
+ def telechat_gelu_back(g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
276
+ """
277
+ Gradient of the tanh approximation of GELU. For reference, the gradient of the exact GELU is:
278
+ 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
279
+
280
+ Args:
281
+ g (`torch.tensor`, *required*):
282
+ gradient output tensor
283
+ x (`torch.tensor`, *required*):
284
+ input tensor
285
+ """
286
+ x = x[0] # x is a tuple of 1 element, needs to unpack it first
287
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
288
+ # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
289
+ ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
290
+ return ff * g
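The closed-form gradient in `telechat_gelu_back` can be sanity-checked against a central finite difference. This is a small scalar sketch using only the `math` module rather than `torch`; the function names are hypothetical, while the constants are the ones used in the functions shown here:

```python
import math

SQRT_2_OVER_PI = 0.79788456  # same constant as in the forward pass


def gelu_tanh(x):
    """Tanh approximation of GELU for a single float."""
    return x * 0.5 * (1.0 + math.tanh(SQRT_2_OVER_PI * x * (1 + 0.044715 * x * x)))


def gelu_tanh_grad(x):
    """Analytic derivative of gelu_tanh, written as in the backward pass."""
    t = math.tanh(SQRT_2_OVER_PI * x * (1 + 0.044715 * x * x))
    # d/dx of the tanh argument is sqrt(2/pi) + 3 * 0.044715 * sqrt(2/pi) * x^2
    return 0.5 * x * ((1 - t * t) * (SQRT_2_OVER_PI + 0.1070322243 * x * x)) + 0.5 * (1 + t)


def finite_difference(f, x, eps=1e-5):
    """Central difference estimate of f'(x)."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)


points = (-3.0, -1.0, -0.1, 0.0, 0.5, 2.0)
max_error = max(abs(gelu_tanh_grad(x) - finite_difference(gelu_tanh, x)) for x in points)
print(max_error)
```

A small `max_error` indicates that the hand-written derivative matches the forward function at the sampled points.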
291
+
292
+
293
+ class GeLUFunction(torch.autograd.Function):
294
+ @staticmethod
295
+ def forward(ctx, input: torch.Tensor) -> torch.Tensor:
296
+ ctx.save_for_backward(input)
297
+ return telechat_gelu_forward(input)
298
+
299
+ @staticmethod
300
+ def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
301
+ input = ctx.saved_tensors
302
+ tmp = telechat_gelu_back(grad_output, input)
303
+ return tmp
304
+
305
+
306
+ class TelechatGelu(nn.Module):
307
+ """
308
+ TelechatGelu wrapper module that uses the simple function in inference mode to keep the model
309
+ torchscriptable and uses the autograd function in training mode to get accurate gradients. Partly
310
+ copied from Megatron-DeepSpeed code and adapted for our needs.
311
+
312
+ See here why autograd functions are not torchscriptable: https://github.com/pytorch/pytorch/issues/22329
313
+ """
314
+
315
+ def __init__(self):
316
+ super().__init__()
317
+
318
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
319
+ if self.training:
320
+ return GeLUFunction.apply(x)
321
+ else:
322
+ return telechat_gelu_forward(x)
323
+
324
+
325
+ class TelechatAttention(nn.Module):
326
+ def __init__(self, config: TelechatConfig, layer_idx):
327
+ super().__init__()
328
+ self.kv_cache = None
329
+ self.layer_idx = layer_idx
330
+
331
+ self.hidden_size = config.hidden_size
332
+ self.num_heads = config.n_head
333
+ self.head_dim = self.hidden_size // self.num_heads
334
+ self.split_size = self.hidden_size
335
+ self.hidden_dropout = config.hidden_dropout
336
+ self.config = config
337
+
338
+ if self.head_dim * self.num_heads != self.hidden_size:
339
+ raise ValueError(
340
+ f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
341
+ f" {self.num_heads})."
342
+ )
343
+
344
+ # Layer-wise attention scaling
345
+ self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
346
+ self.beta = 1.0
347
+
348
+ self.num_key_value_heads = self.num_heads
349
+ kv_projection_size = self.head_dim * self.num_key_value_heads
350
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
351
+ self.query = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
352
+ self.key_value = nn.Linear(self.hidden_size, kv_projection_size * 2, bias=False)
353
+ self.dense = nn.Linear(self.hidden_size, self.hidden_size)
354
+ self.attention_dropout = nn.Dropout(config.attention_dropout)
355
+ self.rotary_emb = RotaryEmbedding(self.head_dim, config=config)
356
+
357
+ self.core_attention_flash = FlashSelfAttention(
358
+ causal=True, attention_dropout=config.attention_dropout
359
+ )
360
+
361
+ self.last_key_layer = None
362
+
363
+ def repeat_kv(self, hidden_states, n_rep):
364
+ slen, batch, num_key_value_heads_per_partition, head_dim = hidden_states.shape
365
+ if n_rep == 1:
366
+ return hidden_states
367
+ hidden_states = hidden_states[:, :, :, None, :].expand(slen, batch, num_key_value_heads_per_partition, n_rep,
368
+ head_dim)
369
+ return hidden_states.reshape(slen, batch, num_key_value_heads_per_partition * n_rep, head_dim)
370
+
371
+ def split_tensor_along_last_dim(self,
372
+ tensor: torch.Tensor,
373
+ num_partitions: int,
374
+ contiguous_split_chunks: bool = False,
375
+ ):
376
+
377
+ # Get the size and dimension.
378
+ last_dim = tensor.dim() - 1
379
+ last_dim_size = tensor.size()[last_dim] // num_partitions
380
+ # Split.
381
+ tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
382
+ # Note: torch.split does not create contiguous tensors by default.
383
+ if contiguous_split_chunks:
384
+ return tuple(chunk.contiguous() for chunk in tensor_list)
385
+
386
+ return tensor_list
387
+
388
+ def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
389
+ batch_size_and_num_heads, seq_length, _ = x.shape
390
+ batch_size = batch_size_and_num_heads // self.num_heads
391
+ x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)
392
+ x = x.permute(0, 2, 1, 3)
393
+ return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)
394
+
395
+ def forward(
396
+ self,
397
+ hidden_states: torch.Tensor,
398
+ residual: torch.Tensor,
399
+ attention_mask: torch.Tensor,
400
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
401
+ use_cache: bool = False,
402
+ output_attentions: bool = False,
403
+ ):
404
+ hidden_states = hidden_states.transpose(1, 0)
405
+ query_layer = self.query(hidden_states)
406
+ new_tensor_shape = query_layer.size()[:-1] + \
407
+ (self.num_heads,
408
+ self.head_dim)
409
+ query_layer = query_layer.view(*new_tensor_shape)
410
+
411
+ mixed_kv_layer = self.key_value(hidden_states)
412
+ new_tensor_shape = mixed_kv_layer.size()[:-1] + \
413
+ (self.num_key_value_heads,
414
+ 2 * self.head_dim)
415
+ mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
416
+ (key_layer, value_layer) = self.split_tensor_along_last_dim(mixed_kv_layer, 2)
417
+
418
+ output_size = (query_layer.size(1),
419
+ query_layer.size(2),
420
+ query_layer.size(0),
421
+ key_layer.size(0))
422
+
423
+ query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
424
+ key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
425
+
426
+ apply_rotary_fn = apply_rotary_pos_emb_torch
427
+
428
+ seq_len = key_layer.shape[0]
429
+ offset = 0
430
+
431
+ if use_cache and layer_past is not None:
432
+ past_key, past_value = layer_past
433
+ offset = past_key.shape[0]
434
+ seq_len += offset
435
+
436
+ cos, sin = self.rotary_emb(value_layer, dtype=value_layer.dtype)
437
+
438
+ query_layer, key_layer = apply_rotary_fn(query_layer, key_layer, cos, sin, offset=offset)
439
+ if use_cache:
440
+ if layer_past is not None:
441
+ past_key, past_value = layer_past
442
+ key_layer = torch.cat((past_key, key_layer[-1, ...].unsqueeze(0)), dim=0)
443
+ value_layer = torch.cat((past_value, value_layer[-1, ...].unsqueeze(0)), dim=0)
444
+ layer_past = key_layer, value_layer
445
+ s, bz, head, dim = value_layer.shape
446
+ s_key = key_layer.shape[0]
447
+ s_query = query_layer.shape[0]
448
+ query_layer = query_layer.reshape((s_query, bz, head, dim))
449
+ key_layer = key_layer.reshape((s_key, bz, head, dim))
450
+
451
+ if self.config.flash_attn:
452
+ q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous() for x in
453
+ (query_layer, key_layer, value_layer)]
454
+ context_layer = self.core_attention_flash(q, k, v)
455
+ context_layer = rearrange(context_layer, 'b s h d -> b s (h d)').contiguous()
456
+ else:
457
+ # [sq, b, np, hn] -> [sq, b * np, hn]
458
+ query_layer = query_layer.reshape(s_query, bz * self.num_heads, dim)
459
+ # [sk, b, np, hn] -> [sk, b * np, hn]
460
+ key_layer = key_layer.reshape(s_key, bz * self.num_heads, dim)
461
+ matmul_result = self.inv_norm_factor * torch.einsum('bik,bkj->bij', query_layer.transpose(0, 1),
462
+ key_layer.transpose(0, 1).transpose(1, 2))
463
+
464
+ attention_scores = matmul_result.view(bz, self.num_heads, s_query, s_key)
465
+
466
+ input_dtype = attention_scores.dtype
467
+ if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
468
+ attention_scores = attention_scores.to(torch.float)
469
+ attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
470
+ attention_probs = F.softmax(attn_weights, dim=-1).to(input_dtype)  # softmax runs in float32, then casts back to input_dtype
471
+ attention_probs = self.attention_dropout(attention_probs)
472
+ attention_probs_reshaped = attention_probs.view(bz * self.num_heads, s_query, s_key)
473
+
474
+ value_layer = value_layer.reshape(s_key, bz * self.num_heads, dim)
475
+ context_layer = torch.bmm(attention_probs_reshaped, value_layer.transpose(0, 1))
476
+ context_layer = self._merge_heads(context_layer)
477
+
478
+ output_tensor = self.dense(context_layer)
479
+
480
+ output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
481
+ present = None
482
+ outputs = (output_tensor, present)
483
+ if output_attentions:
484
+ outputs += (attention_probs,)
485
+
486
+ return output_tensor, layer_past
487
+
488
+
489
+ class TelechatMLP(nn.Module):
490
+ def __init__(self, config: TelechatConfig):
491
+ super().__init__()
492
+ hidden_size = config.hidden_size
493
+ self.gate_proj = nn.Linear(hidden_size, config.ffn_hidden_size, bias=False)
494
+ self.up_proj = nn.Linear(hidden_size, config.ffn_hidden_size, bias=False)
495
+ self.down_proj = nn.Linear(config.ffn_hidden_size, hidden_size, bias=True)
496
+ self.hidden_dropout = config.hidden_dropout
497
+
498
+ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
499
+ intermediate_output = self.down_proj(F.silu(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
500
+ output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
501
+ return output
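`TelechatMLP.forward` computes a SiLU-gated feed-forward layer followed by a residual add. The sketch below spells out the same data flow on plain Python lists for one token, assuming evaluation mode so that dropout is a no-op; `linear` and `gated_mlp` are hypothetical helper names and the tiny weights are made up for illustration:

```python
import math


def silu(v):
    """SiLU activation: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def linear(weight, x, bias=None):
    """Dense layer on a single vector; weight has one row per output unit."""
    out = [sum(w * xi for w, xi in zip(row, x)) for row in weight]
    if bias is not None:
        out = [o + b for o, b in zip(out, bias)]
    return out


def gated_mlp(x, residual, w_gate, w_up, w_down, b_down):
    """down_proj(silu(gate_proj(x)) * up_proj(x)) + residual, without dropout."""
    gate = [silu(g) for g in linear(w_gate, x)]
    up = linear(w_up, x)
    hidden = [g * u for g, u in zip(gate, up)]
    out = linear(w_down, hidden, b_down)
    return [r + o for r, o in zip(residual, out)]


# hidden_size = 2, ffn_hidden_size = 3; a zero gate closes the MLP entirely.
x = [1.0, 2.0]
zero_gate = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
w_up = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
w_down = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
result = gated_mlp(x, x, zero_gate, w_up, w_down, b_down=[0.5, -0.5])
print(result)
```

With a zero gate only the `down_proj` bias and the residual reach the output, which makes the role of each projection easy to see.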
502
+
503
+
504
+ class TelechatBlock(nn.Module):
505
+ def __init__(self, config: TelechatConfig, layer_idx):
506
+ super().__init__()
507
+ hidden_size = config.hidden_size
508
+
509
+ self.input_layernorm = MixedFusedRMSNorm(hidden_size, eps=config.layer_norm_epsilon)
510
+ self.num_heads = config.n_head
511
+ self.layer_idx = layer_idx
512
+ self.self_attention = TelechatAttention(config, layer_idx)
513
+ self.post_attention_layernorm = MixedFusedRMSNorm(hidden_size, eps=config.layer_norm_epsilon)
514
+
515
+ self.mlp = TelechatMLP(config)
516
+
517
+ self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
518
+ self.hidden_dropout = config.hidden_dropout
519
+
520
+ def forward(
521
+ self,
522
+ hidden_states: torch.Tensor,
523
+ attention_mask: torch.Tensor,
524
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
525
+ use_cache: bool = False,
526
+ output_attentions: bool = False,
527
+ ):
528
+ layernorm_output = self.input_layernorm(hidden_states)
529
+ if self.apply_residual_connection_post_layernorm:
530
+ residual = layernorm_output
531
+ else:
532
+ residual = hidden_states
533
+
534
+ attn_outputs = self.self_attention(
535
+ layernorm_output,
536
+ residual,
537
+ layer_past=layer_past,
538
+ attention_mask=attention_mask,
539
+ use_cache=use_cache,
540
+ output_attentions=output_attentions,
541
+ )
542
+
543
+ attention_output = attn_outputs[0]
544
+ outputs = attn_outputs[1:]
545
+ layernorm_output = self.post_attention_layernorm(attention_output)
546
+
547
+ if self.apply_residual_connection_post_layernorm:
548
+ residual = layernorm_output
549
+ else:
550
+ residual = attention_output
551
+ output = self.mlp(layernorm_output, residual)
552
+
553
+ if use_cache:
554
+ outputs = (output,) + outputs
555
+ else:
556
+ outputs = (output,) + outputs[1:]
557
+
558
+ return outputs
559
+
560
+
561
+ class TelechatPreTrainedModel(PreTrainedModel):
562
+ config_class = TelechatConfig
563
+ base_model_prefix = "transformer"
564
+ supports_gradient_checkpointing = True
565
+ _no_split_modules = ["TelechatBlock"]
566
+ _skip_keys_device_placement = "past_key_values"
567
+
568
+ def __init__(self, *inputs, **kwargs):
569
+ super().__init__(*inputs, **kwargs)
570
+
571
+ def _init_weights(self, module: nn.Module):
572
+ """Initialize the weights."""
573
+ if isinstance(module, nn.Linear):
574
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
575
+ if module.bias is not None:
576
+ module.bias.data.zero_()
577
+
578
+ elif isinstance(module, nn.Embedding):
579
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
580
+ if module.padding_idx is not None:
581
+ module.weight.data[module.padding_idx].zero_()
582
+
583
+ elif isinstance(module, LayerNorm):
584
+ module.bias.data.zero_()
585
+ module.weight.data.fill_(1.0)
586
+
587
+ def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
588
+ if isinstance(module, TelechatModel):
589
+ module.gradient_checkpointing = value
590
+
591
+
592
+ class TelechatModel(TelechatPreTrainedModel):
593
+ def __init__(self, config: TelechatConfig):
594
+ super().__init__(config)
595
+
596
+ self.embed_dim = config.hidden_size
597
+ self.num_heads = config.n_head
598
+ self.config = config
599
+ self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
600
+ if self.config.embed_layernorm:
601
+ self.word_embeddings_layernorm = MixedFusedRMSNorm(self.embed_dim, eps=config.layer_norm_epsilon)
602
+
603
+ self.h = nn.ModuleList([TelechatBlock(config, _) for _ in range(config.num_hidden_layers)])
604
+ self.ln_f = MixedFusedRMSNorm(self.embed_dim, eps=config.layer_norm_epsilon)
605
+ self.gradient_checkpointing = False
606
+ self.post_init()
607
+
608
+ def get_input_embeddings(self):
609
+ return self.word_embeddings
610
+
611
+ def _prepare_attn_mask(
612
+ self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
613
+ ) -> torch.BoolTensor:
614
+ combined_attention_mask = None
615
+ device = attention_mask.device
616
+ _, src_length = input_shape
617
+
618
+ if src_length > 1:
619
+ combined_attention_mask = _make_causal_mask(
620
+ input_shape, device=device, past_key_values_length=past_key_values_length
621
+ )
622
+ expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
623
+ combined_attention_mask = (
624
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
625
+ )
626
+
627
+ return combined_attention_mask
628
+
629
+ def set_input_embeddings(self, new_embeddings: torch.Tensor):
630
+ self.word_embeddings = new_embeddings
631
+
632
+ def forward(
633
+ self,
634
+ input_ids: Optional[torch.LongTensor] = None,
635
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
636
+ attention_mask: Optional[torch.Tensor] = None,
637
+ inputs_embeds: Optional[torch.LongTensor] = None,
638
+ use_cache: Optional[bool] = None,
639
+ output_attentions: Optional[bool] = None,
640
+ output_hidden_states: Optional[bool] = None,
641
+ return_dict: Optional[bool] = None,
642
+ **deprecated_arguments,
643
+ ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
644
+
645
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
646
+ output_hidden_states = (
647
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
648
+ )
649
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
650
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
651
+
652
+ if input_ids is not None:
653
+ batch_size, seq_length = input_ids.shape
654
+ elif inputs_embeds is not None:
655
+ batch_size, seq_length, _ = inputs_embeds.shape
656
+
657
+ if past_key_values is None:
658
+ past_key_values = tuple([None] * len(self.h))
659
+
660
+ if inputs_embeds is None:
661
+ inputs_embeds = self.word_embeddings(input_ids)
662
+ hidden_states = inputs_embeds
663
+
664
+ if self.config.embed_layernorm:
665
+ hidden_states = self.word_embeddings_layernorm(inputs_embeds)
666
+
667
+ presents = () if use_cache else None
668
+ all_self_attentions = () if output_attentions else None
669
+ all_hidden_states = () if output_hidden_states else None
670
+
671
+ if self.gradient_checkpointing and self.training:
672
+ if use_cache:
673
+ use_cache = False
674
+
675
+ seq_length_with_past = seq_length
676
+ past_key_values_length = 0
677
+ if past_key_values[0] is not None:
678
+ past_key_values_length = past_key_values[0][0].shape[2]
679
+ seq_length_with_past = seq_length_with_past + past_key_values_length
680
+ if attention_mask is None:
681
+ attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
682
+ else:
683
+ attention_mask = attention_mask.to(hidden_states.device)
684
+ causal_mask = self._prepare_attn_mask(
685
+ attention_mask,
686
+ input_shape=(batch_size, seq_length),
687
+ past_key_values_length=past_key_values_length,
688
+ )
689
+
690
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
691
+ if output_hidden_states:
692
+ all_hidden_states = all_hidden_states + (hidden_states,)
693
+
694
+ if self.gradient_checkpointing and self.training:
695
+
696
+ def create_custom_forward(module):
697
+ def custom_forward(*inputs):
698
+ # None for past_key_value
699
+ return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
700
+
701
+ return custom_forward
702
+
703
+ outputs = torch.utils.checkpoint.checkpoint(
704
+ create_custom_forward(block),
705
+ hidden_states,
706
+ causal_mask,
707
+ layer_past,
708
+ )
709
+ else:
710
+ outputs = block(
711
+ hidden_states,
712
+ layer_past=layer_past,
713
+ attention_mask=causal_mask,
714
+ use_cache=use_cache,
715
+ output_attentions=output_attentions,
716
+ )
717
+
718
+ hidden_states = outputs[0]
719
+ if use_cache is True:
720
+ presents = presents + (outputs[1],)
721
+
722
+ if output_attentions:
723
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
724
+ hidden_states = self.ln_f(hidden_states)
725
+ if output_hidden_states:
726
+ all_hidden_states = all_hidden_states + (hidden_states,)
727
+ if not return_dict:
728
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
729
+ return BaseModelOutputWithPastAndCrossAttentions(
730
+ last_hidden_state=hidden_states,
731
+ past_key_values=presents,
732
+ hidden_states=all_hidden_states,
733
+ attentions=all_self_attentions,
734
+ )
735
+
736
+
737
+ class TelechatForCausalLM(TelechatPreTrainedModel):
738
+ # _tied_weights_keys = ["lm_head.weight"]
739
+ _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
740
+
741
+ def __init__(self, config: TelechatConfig):
742
+ super().__init__(config)
743
+ self.transformer = TelechatModel(config)
744
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
745
+ self.post_init()
746
+
747
+ def get_output_embeddings(self):
748
+ return self.lm_head
749
+
750
+ def set_output_embeddings(self, new_embeddings: torch.Tensor):
751
+ self.lm_head = new_embeddings
752
+
753
+ def prepare_inputs_for_generation(
754
+ self,
755
+ input_ids: torch.LongTensor,
756
+ past_key_values: Optional[torch.Tensor] = None,
757
+ attention_mask: Optional[torch.Tensor] = None,
758
+ inputs_embeds: Optional[torch.Tensor] = None,
759
+ **kwargs,
760
+ ) -> dict:
761
+ if past_key_values:
762
+ input_ids = input_ids[:, -1].unsqueeze(-1)
763
+ if inputs_embeds is not None and past_key_values is None:
764
+ model_inputs = {"inputs_embeds": inputs_embeds}
765
+ else:
766
+ model_inputs = {"input_ids": input_ids}
767
+
768
+ model_inputs.update(
769
+ {
770
+ "past_key_values": past_key_values,
771
+ "use_cache": kwargs.get("use_cache"),
772
+ "attention_mask": attention_mask,
773
+ }
774
+ )
775
+ return model_inputs
776
+
777
+ def forward(
778
+ self,
779
+ input_ids: Optional[torch.LongTensor] = None,
780
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
781
+ attention_mask: Optional[torch.Tensor] = None,
782
+ inputs_embeds: Optional[torch.Tensor] = None,
783
+ labels: Optional[torch.Tensor] = None,
784
+ use_cache: Optional[bool] = None,
785
+ output_attentions: Optional[bool] = None,
786
+ output_hidden_states: Optional[bool] = None,
787
+ return_dict: Optional[bool] = None,
788
+ **deprecated_arguments,
789
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
790
+
791
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
792
+
793
+ transformer_outputs = self.transformer(
794
+ input_ids,
795
+ past_key_values=past_key_values,
796
+ attention_mask=attention_mask,
797
+ inputs_embeds=inputs_embeds,
798
+ use_cache=use_cache,
799
+ output_attentions=output_attentions,
800
+ output_hidden_states=output_hidden_states,
801
+ return_dict=return_dict,
802
+ )
803
+ hidden_states = transformer_outputs[0]
804
+ lm_logits = self.lm_head(hidden_states)
805
+
806
+ loss = None
807
+ if labels is not None:
808
+ labels = labels.to(lm_logits.device)
809
+ shift_logits = lm_logits[..., :-1, :].contiguous()
810
+ shift_labels = labels[..., 1:].contiguous()
811
+ batch_size, seq_length, vocab_size = shift_logits.shape
812
+ loss_fct = CrossEntropyLoss()
813
+ loss = loss_fct(
814
+ shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
815
+ )
816
+
817
+ if not return_dict:
818
+ output = (lm_logits,) + transformer_outputs[1:]
819
+ return ((loss,) + output) if loss is not None else output
820
+
821
+ return CausalLMOutputWithCrossAttentions(
822
+ loss=loss,
823
+ logits=lm_logits,
824
+ past_key_values=transformer_outputs.past_key_values,
825
+ hidden_states=transformer_outputs.hidden_states,
826
+ attentions=transformer_outputs.attentions,
827
+ )
828
+
829
+ def chat(self, tokenizer, question: str = '', history: Union[List[Dict], History] = None, stream: bool = False,
830
+ generation_config: Optional[GenerationConfig] = None, **kwargs):
831
+ """
832
+ Args:
833
+ tokenizer: the tokenizer of telechat
834
+ question: the question the model should answer in this turn
835
+ history: prior conversation turns, used to format the input for telechat
836
+ stream: if True, return a streamer that yields text as it is generated; otherwise return the full text
837
+ generation_config: configuration for generation
838
+ **kwargs: args which will update the generation config or pass to model forward
839
+ """
840
+ generation_config = generation_config or self.generation_config
841
+ if not generation_config:
842
+ logger.error("generation_config is None")
843
+ raise ValueError("generation_config must not be None")
844
+ if not question:
845
+ logger.error("question is empty")
846
+ raise ValueError("question must not be empty")
847
+ if history is None:
848
+ history = []
849
+
850
+ # update and validate generation_config before building inputs
851
+
852
+ generation_config = copy.deepcopy(generation_config)
853
+ user_id = generation_config.user_token_id
854
+ bot_id = generation_config.bot_token_id
855
+ model_kwargs = generation_config.update(**kwargs)
856
+ generation_config.validate()
857
+
858
+ # convert to History
859
+ if not isinstance(history, History):
860
+ history = History(tokenizer, history)
861
+
862
+ inputs = self.build_inputs_for_chat(tokenizer, question, history, generation_config, user_id, bot_id)
863
+ history.append({"role": "user", "content": question})
864
+ if stream:
865
+ streamer = TelechatIterTextStreamer(tokenizer, history, skip_prompt=True)
866
+ Thread(target=self.generate, kwargs=dict(
867
+ inputs=inputs.to(self.device), streamer=streamer,
868
+ generation_config=generation_config, **model_kwargs
869
+ )).start()
870
+ return streamer
871
+ else:
872
+ outputs = self.generate(inputs.to(self.device), generation_config=generation_config, **model_kwargs)
873
+ response = tokenizer.decode(outputs[0][len(inputs[0]):-1])
874
+ history.append({"role": "bot", "content": response})
875
+ return response, history
876
+
877
+ def build_inputs_for_chat(self, tokenizer, question, history, generation_config, usr_id, bot_id):
878
+ """
879
+ Truncate the history to fit the length budget and build the model inputs.
880
+ """
881
+ # first tokenize question
882
+ q_token = tokenizer(question)
883
+ qa_history = copy.deepcopy(history)
884
+
885
+ # maximum number of tokens available for the prompt
886
+ model_max_length = self.config.seq_length
887
+ build_max_length = max(0, model_max_length - generation_config.max_new_tokens) \
888
+ if generation_config.max_new_tokens else max(0, generation_config.max_length)
889
+ if build_max_length < 3:
890
+ logger.warning("The model cannot meet the input length requirement; please check the config.")
891
+ raise ValueError("build_max_length must be at least 3; check seq_length, max_new_tokens and max_length")
892
+
893
+ # truncate the question from the left
894
+ input_tokens = [usr_id] + q_token["input_ids"][-build_max_length + 1:] + [bot_id]
895
+ length = len(input_tokens)
896
+
897
+ while len(qa_history) != 0:
898
+ message = qa_history.pop()
899
+ if message["role"] == "user":
900
+ tokens = [usr_id] + message["input_ids"]
901
+ elif message["role"] == "bot":
902
+ tokens = [bot_id] + message["input_ids"] + [generation_config.eos_token_id]
903
+ else:
904
+ tokens = []
905
+ if len(tokens) + len(input_tokens) >= build_max_length:
906
+ break
907
+ else:
908
+ input_tokens = tokens + input_tokens
909
+
910
+ return torch.tensor([input_tokens], dtype=torch.int64)
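`build_inputs_for_chat` walks the history from newest to oldest and prepends turns until the length budget is used up. A minimal stand-alone sketch of that loop on plain lists follows; the function name `build_input_tokens` and the token ids are hypothetical, and the sketch compares the running prompt length against the budget, which is an assumption about the intended behaviour:

```python
def build_input_tokens(question_ids, history, user_id, bot_id, eos_id, max_length):
    """Assemble prompt tokens from a question and prior turns.

    history is a list of {"role": ..., "input_ids": [...]} dicts, oldest first.
    """
    # Keep the tail of the question, then wrap it in the role markers.
    input_tokens = [user_id] + question_ids[-max_length + 1:] + [bot_id]
    remaining = list(history)  # copy so the caller's history is untouched
    while remaining:
        message = remaining.pop()  # newest turn first
        if message["role"] == "user":
            tokens = [user_id] + message["input_ids"]
        elif message["role"] == "bot":
            tokens = [bot_id] + message["input_ids"] + [eos_id]
        else:
            tokens = []
        if len(tokens) + len(input_tokens) >= max_length:
            break  # older turns no longer fit
        input_tokens = tokens + input_tokens
    return input_tokens


history = [
    {"role": "user", "input_ids": [1, 2]},
    {"role": "bot", "input_ids": [3]},
]
full = build_input_tokens([4, 5], history, user_id=100, bot_id=200, eos_id=300, max_length=20)
short = build_input_tokens([4, 5], history, user_id=100, bot_id=200, eos_id=300, max_length=8)
print(full)
print(short)
```

With a generous budget both turns are kept; with the tighter budget the oldest user turn is dropped while the most recent bot reply still fits.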
pytorch_model.bin.index.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"metadata": {"total_size": 24772864000}, "weight_map": {"lm_head.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.0.input_layernorm.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.0.mlp.down_proj.bias": "pytorch_model_00001-of-00004.bin", "transformer.h.0.mlp.down_proj.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.0.mlp.gate_proj.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.0.mlp.up_proj.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.0.post_attention_layernorm.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.0.self_attention.dense.bias": "pytorch_model_00001-of-00004.bin", "transformer.h.0.self_attention.dense.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.0.self_attention.key_value.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.0.self_attention.query.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.1.input_layernorm.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.1.mlp.down_proj.bias": "pytorch_model_00001-of-00004.bin", "transformer.h.1.mlp.down_proj.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.1.mlp.gate_proj.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.1.mlp.up_proj.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.1.post_attention_layernorm.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.1.self_attention.dense.bias": "pytorch_model_00001-of-00004.bin", "transformer.h.1.self_attention.dense.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.1.self_attention.key_value.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.1.self_attention.query.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.10.input_layernorm.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.10.mlp.down_proj.bias": "pytorch_model_00002-of-00004.bin", "transformer.h.10.mlp.down_proj.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.10.mlp.gate_proj.weight": 
"pytorch_model_00002-of-00004.bin", "transformer.h.10.mlp.up_proj.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.10.post_attention_layernorm.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.10.self_attention.dense.bias": "pytorch_model_00002-of-00004.bin", "transformer.h.10.self_attention.dense.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.10.self_attention.key_value.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.10.self_attention.query.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.11.input_layernorm.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.11.mlp.down_proj.bias": "pytorch_model_00002-of-00004.bin", "transformer.h.11.mlp.down_proj.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.11.mlp.gate_proj.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.11.mlp.up_proj.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.11.post_attention_layernorm.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.11.self_attention.dense.bias": "pytorch_model_00002-of-00004.bin", "transformer.h.11.self_attention.dense.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.11.self_attention.key_value.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.11.self_attention.query.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.12.input_layernorm.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.12.mlp.down_proj.bias": "pytorch_model_00002-of-00004.bin", "transformer.h.12.mlp.down_proj.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.12.mlp.gate_proj.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.12.mlp.up_proj.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.12.post_attention_layernorm.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.12.self_attention.dense.bias": "pytorch_model_00002-of-00004.bin", "transformer.h.12.self_attention.dense.weight": "pytorch_model_00002-of-00004.bin", 
"transformer.h.12.self_attention.key_value.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.12.self_attention.query.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.13.input_layernorm.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.13.mlp.down_proj.bias": "pytorch_model_00002-of-00004.bin", "transformer.h.13.mlp.down_proj.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.13.mlp.gate_proj.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.13.mlp.up_proj.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.13.post_attention_layernorm.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.13.self_attention.dense.bias": "pytorch_model_00002-of-00004.bin", "transformer.h.13.self_attention.dense.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.13.self_attention.key_value.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.13.self_attention.query.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.14.input_layernorm.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.14.mlp.down_proj.bias": "pytorch_model_00002-of-00004.bin", "transformer.h.14.mlp.down_proj.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.14.mlp.gate_proj.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.14.mlp.up_proj.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.14.post_attention_layernorm.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.14.self_attention.dense.bias": "pytorch_model_00002-of-00004.bin", "transformer.h.14.self_attention.dense.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.14.self_attention.key_value.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.14.self_attention.query.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.15.input_layernorm.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.15.mlp.down_proj.bias": "pytorch_model_00002-of-00004.bin", "transformer.h.15.mlp.down_proj.weight": 
"pytorch_model_00002-of-00004.bin", "transformer.h.15.mlp.gate_proj.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.15.mlp.up_proj.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.15.post_attention_layernorm.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.15.self_attention.dense.bias": "pytorch_model_00002-of-00004.bin", "transformer.h.15.self_attention.dense.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.15.self_attention.key_value.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.15.self_attention.query.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.16.input_layernorm.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.16.mlp.down_proj.bias": "pytorch_model_00002-of-00004.bin", "transformer.h.16.mlp.down_proj.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.16.mlp.gate_proj.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.16.mlp.up_proj.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.16.post_attention_layernorm.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.16.self_attention.dense.bias": "pytorch_model_00002-of-00004.bin", "transformer.h.16.self_attention.dense.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.16.self_attention.key_value.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.16.self_attention.query.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.17.input_layernorm.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.17.mlp.down_proj.bias": "pytorch_model_00002-of-00004.bin", "transformer.h.17.mlp.down_proj.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.17.mlp.gate_proj.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.17.mlp.up_proj.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.17.post_attention_layernorm.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.17.self_attention.dense.bias": "pytorch_model_00002-of-00004.bin", 
"transformer.h.17.self_attention.dense.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.17.self_attention.key_value.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.17.self_attention.query.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.18.input_layernorm.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.18.mlp.down_proj.bias": "pytorch_model_00003-of-00004.bin", "transformer.h.18.mlp.down_proj.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.18.mlp.gate_proj.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.18.mlp.up_proj.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.18.post_attention_layernorm.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.18.self_attention.dense.bias": "pytorch_model_00003-of-00004.bin", "transformer.h.18.self_attention.dense.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.18.self_attention.key_value.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.18.self_attention.query.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.19.input_layernorm.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.19.mlp.down_proj.bias": "pytorch_model_00003-of-00004.bin", "transformer.h.19.mlp.down_proj.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.19.mlp.gate_proj.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.19.mlp.up_proj.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.19.post_attention_layernorm.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.19.self_attention.dense.bias": "pytorch_model_00003-of-00004.bin", "transformer.h.19.self_attention.dense.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.19.self_attention.key_value.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.19.self_attention.query.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.2.input_layernorm.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.2.mlp.down_proj.bias": 
"pytorch_model_00001-of-00004.bin", "transformer.h.2.mlp.down_proj.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.2.mlp.gate_proj.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.2.mlp.up_proj.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.2.post_attention_layernorm.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.2.self_attention.dense.bias": "pytorch_model_00001-of-00004.bin", "transformer.h.2.self_attention.dense.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.2.self_attention.key_value.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.2.self_attention.query.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.20.input_layernorm.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.20.mlp.down_proj.bias": "pytorch_model_00003-of-00004.bin", "transformer.h.20.mlp.down_proj.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.20.mlp.gate_proj.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.20.mlp.up_proj.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.20.post_attention_layernorm.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.20.self_attention.dense.bias": "pytorch_model_00003-of-00004.bin", "transformer.h.20.self_attention.dense.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.20.self_attention.key_value.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.20.self_attention.query.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.21.input_layernorm.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.21.mlp.down_proj.bias": "pytorch_model_00003-of-00004.bin", "transformer.h.21.mlp.down_proj.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.21.mlp.gate_proj.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.21.mlp.up_proj.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.21.post_attention_layernorm.weight": "pytorch_model_00003-of-00004.bin", 
"transformer.h.21.self_attention.dense.bias": "pytorch_model_00003-of-00004.bin", "transformer.h.21.self_attention.dense.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.21.self_attention.key_value.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.21.self_attention.query.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.22.input_layernorm.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.22.mlp.down_proj.bias": "pytorch_model_00003-of-00004.bin", "transformer.h.22.mlp.down_proj.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.22.mlp.gate_proj.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.22.mlp.up_proj.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.22.post_attention_layernorm.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.22.self_attention.dense.bias": "pytorch_model_00003-of-00004.bin", "transformer.h.22.self_attention.dense.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.22.self_attention.key_value.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.22.self_attention.query.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.23.input_layernorm.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.23.mlp.down_proj.bias": "pytorch_model_00003-of-00004.bin", "transformer.h.23.mlp.down_proj.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.23.mlp.gate_proj.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.23.mlp.up_proj.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.23.post_attention_layernorm.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.23.self_attention.dense.bias": "pytorch_model_00003-of-00004.bin", "transformer.h.23.self_attention.dense.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.23.self_attention.key_value.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.23.self_attention.query.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.24.input_layernorm.weight": 
"pytorch_model_00003-of-00004.bin", "transformer.h.24.mlp.down_proj.bias": "pytorch_model_00003-of-00004.bin", "transformer.h.24.mlp.down_proj.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.24.mlp.gate_proj.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.24.mlp.up_proj.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.24.post_attention_layernorm.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.24.self_attention.dense.bias": "pytorch_model_00003-of-00004.bin", "transformer.h.24.self_attention.dense.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.24.self_attention.key_value.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.24.self_attention.query.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.25.input_layernorm.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.25.mlp.down_proj.bias": "pytorch_model_00003-of-00004.bin", "transformer.h.25.mlp.down_proj.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.25.mlp.gate_proj.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.25.mlp.up_proj.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.25.post_attention_layernorm.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.25.self_attention.dense.bias": "pytorch_model_00003-of-00004.bin", "transformer.h.25.self_attention.dense.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.25.self_attention.key_value.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.25.self_attention.query.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.26.input_layernorm.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.26.mlp.down_proj.bias": "pytorch_model_00003-of-00004.bin", "transformer.h.26.mlp.down_proj.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.26.mlp.gate_proj.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.26.mlp.up_proj.weight": "pytorch_model_00003-of-00004.bin", 
"transformer.h.26.post_attention_layernorm.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.26.self_attention.dense.bias": "pytorch_model_00003-of-00004.bin", "transformer.h.26.self_attention.dense.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.26.self_attention.key_value.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.26.self_attention.query.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.27.input_layernorm.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.27.mlp.down_proj.bias": "pytorch_model_00003-of-00004.bin", "transformer.h.27.mlp.down_proj.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.27.mlp.gate_proj.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.27.mlp.up_proj.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.27.post_attention_layernorm.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.27.self_attention.dense.bias": "pytorch_model_00003-of-00004.bin", "transformer.h.27.self_attention.dense.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.27.self_attention.key_value.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.27.self_attention.query.weight": "pytorch_model_00003-of-00004.bin", "transformer.h.28.input_layernorm.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.28.mlp.down_proj.bias": "pytorch_model_00004-of-00004.bin", "transformer.h.28.mlp.down_proj.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.28.mlp.gate_proj.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.28.mlp.up_proj.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.28.post_attention_layernorm.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.28.self_attention.dense.bias": "pytorch_model_00004-of-00004.bin", "transformer.h.28.self_attention.dense.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.28.self_attention.key_value.weight": "pytorch_model_00004-of-00004.bin", 
"transformer.h.28.self_attention.query.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.29.input_layernorm.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.29.mlp.down_proj.bias": "pytorch_model_00004-of-00004.bin", "transformer.h.29.mlp.down_proj.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.29.mlp.gate_proj.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.29.mlp.up_proj.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.29.post_attention_layernorm.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.29.self_attention.dense.bias": "pytorch_model_00004-of-00004.bin", "transformer.h.29.self_attention.dense.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.29.self_attention.key_value.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.29.self_attention.query.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.3.input_layernorm.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.3.mlp.down_proj.bias": "pytorch_model_00001-of-00004.bin", "transformer.h.3.mlp.down_proj.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.3.mlp.gate_proj.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.3.mlp.up_proj.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.3.post_attention_layernorm.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.3.self_attention.dense.bias": "pytorch_model_00001-of-00004.bin", "transformer.h.3.self_attention.dense.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.3.self_attention.key_value.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.3.self_attention.query.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.30.input_layernorm.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.30.mlp.down_proj.bias": "pytorch_model_00004-of-00004.bin", "transformer.h.30.mlp.down_proj.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.30.mlp.gate_proj.weight": "pytorch_model_00004-of-00004.bin", 
"transformer.h.30.mlp.up_proj.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.30.post_attention_layernorm.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.30.self_attention.dense.bias": "pytorch_model_00004-of-00004.bin", "transformer.h.30.self_attention.dense.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.30.self_attention.key_value.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.30.self_attention.query.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.31.input_layernorm.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.31.mlp.down_proj.bias": "pytorch_model_00004-of-00004.bin", "transformer.h.31.mlp.down_proj.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.31.mlp.gate_proj.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.31.mlp.up_proj.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.31.post_attention_layernorm.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.31.self_attention.dense.bias": "pytorch_model_00004-of-00004.bin", "transformer.h.31.self_attention.dense.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.31.self_attention.key_value.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.31.self_attention.query.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.32.input_layernorm.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.32.mlp.down_proj.bias": "pytorch_model_00004-of-00004.bin", "transformer.h.32.mlp.down_proj.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.32.mlp.gate_proj.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.32.mlp.up_proj.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.32.post_attention_layernorm.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.32.self_attention.dense.bias": "pytorch_model_00004-of-00004.bin", "transformer.h.32.self_attention.dense.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.32.self_attention.key_value.weight": 
"pytorch_model_00004-of-00004.bin", "transformer.h.32.self_attention.query.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.33.input_layernorm.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.33.mlp.down_proj.bias": "pytorch_model_00004-of-00004.bin", "transformer.h.33.mlp.down_proj.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.33.mlp.gate_proj.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.33.mlp.up_proj.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.33.post_attention_layernorm.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.33.self_attention.dense.bias": "pytorch_model_00004-of-00004.bin", "transformer.h.33.self_attention.dense.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.33.self_attention.key_value.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.33.self_attention.query.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.34.input_layernorm.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.34.mlp.down_proj.bias": "pytorch_model_00004-of-00004.bin", "transformer.h.34.mlp.down_proj.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.34.mlp.gate_proj.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.34.mlp.up_proj.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.34.post_attention_layernorm.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.34.self_attention.dense.bias": "pytorch_model_00004-of-00004.bin", "transformer.h.34.self_attention.dense.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.34.self_attention.key_value.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.34.self_attention.query.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.35.input_layernorm.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.35.mlp.down_proj.bias": "pytorch_model_00004-of-00004.bin", "transformer.h.35.mlp.down_proj.weight": "pytorch_model_00004-of-00004.bin", 
"transformer.h.35.mlp.gate_proj.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.35.mlp.up_proj.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.35.post_attention_layernorm.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.35.self_attention.dense.bias": "pytorch_model_00004-of-00004.bin", "transformer.h.35.self_attention.dense.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.35.self_attention.key_value.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.35.self_attention.query.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.36.input_layernorm.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.36.mlp.down_proj.bias": "pytorch_model_00004-of-00004.bin", "transformer.h.36.mlp.down_proj.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.36.mlp.gate_proj.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.36.mlp.up_proj.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.36.post_attention_layernorm.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.36.self_attention.dense.bias": "pytorch_model_00004-of-00004.bin", "transformer.h.36.self_attention.dense.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.36.self_attention.key_value.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.36.self_attention.query.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.37.input_layernorm.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.37.mlp.down_proj.bias": "pytorch_model_00004-of-00004.bin", "transformer.h.37.mlp.down_proj.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.37.mlp.gate_proj.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.37.mlp.up_proj.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.37.post_attention_layernorm.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.37.self_attention.dense.bias": "pytorch_model_00004-of-00004.bin", "transformer.h.37.self_attention.dense.weight": 
"pytorch_model_00004-of-00004.bin", "transformer.h.37.self_attention.key_value.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.37.self_attention.query.weight": "pytorch_model_00004-of-00004.bin", "transformer.h.4.input_layernorm.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.4.mlp.down_proj.bias": "pytorch_model_00001-of-00004.bin", "transformer.h.4.mlp.down_proj.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.4.mlp.gate_proj.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.4.mlp.up_proj.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.4.post_attention_layernorm.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.4.self_attention.dense.bias": "pytorch_model_00001-of-00004.bin", "transformer.h.4.self_attention.dense.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.4.self_attention.key_value.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.4.self_attention.query.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.5.input_layernorm.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.5.mlp.down_proj.bias": "pytorch_model_00001-of-00004.bin", "transformer.h.5.mlp.down_proj.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.5.mlp.gate_proj.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.5.mlp.up_proj.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.5.post_attention_layernorm.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.5.self_attention.dense.bias": "pytorch_model_00001-of-00004.bin", "transformer.h.5.self_attention.dense.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.5.self_attention.key_value.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.5.self_attention.query.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.6.input_layernorm.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.6.mlp.down_proj.bias": "pytorch_model_00001-of-00004.bin", "transformer.h.6.mlp.down_proj.weight": 
"pytorch_model_00001-of-00004.bin", "transformer.h.6.mlp.gate_proj.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.6.mlp.up_proj.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.6.post_attention_layernorm.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.6.self_attention.dense.bias": "pytorch_model_00001-of-00004.bin", "transformer.h.6.self_attention.dense.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.6.self_attention.key_value.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.6.self_attention.query.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.7.input_layernorm.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.7.mlp.down_proj.bias": "pytorch_model_00001-of-00004.bin", "transformer.h.7.mlp.down_proj.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.7.mlp.gate_proj.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.7.mlp.up_proj.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.7.post_attention_layernorm.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.7.self_attention.dense.bias": "pytorch_model_00001-of-00004.bin", "transformer.h.7.self_attention.dense.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.7.self_attention.key_value.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.7.self_attention.query.weight": "pytorch_model_00001-of-00004.bin", "transformer.h.8.input_layernorm.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.8.mlp.down_proj.bias": "pytorch_model_00002-of-00004.bin", "transformer.h.8.mlp.down_proj.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.8.mlp.gate_proj.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.8.mlp.up_proj.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.8.post_attention_layernorm.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.8.self_attention.dense.bias": "pytorch_model_00002-of-00004.bin", "transformer.h.8.self_attention.dense.weight": 
"pytorch_model_00002-of-00004.bin", "transformer.h.8.self_attention.key_value.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.8.self_attention.query.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.9.input_layernorm.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.9.mlp.down_proj.bias": "pytorch_model_00002-of-00004.bin", "transformer.h.9.mlp.down_proj.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.9.mlp.gate_proj.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.9.mlp.up_proj.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.9.post_attention_layernorm.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.9.self_attention.dense.bias": "pytorch_model_00002-of-00004.bin", "transformer.h.9.self_attention.dense.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.9.self_attention.key_value.weight": "pytorch_model_00002-of-00004.bin", "transformer.h.9.self_attention.query.weight": "pytorch_model_00002-of-00004.bin", "transformer.ln_f.weight": "pytorch_model_00004-of-00004.bin", "transformer.word_embeddings.weight": "pytorch_model_00001-of-00004.bin"}}
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
+ {
+   "bos_token": {
+     "content": "<s>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eos_token": {
+     "content": "</s>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
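As an illustration only (not part of the commit), here is a minimal sketch of how the map in `special_tokens_map.json` could be inspected. It embeds the JSON as a string rather than reading it from disk, and the helper names `token_strings` and `shared_tokens` are hypothetical. It shows that `pad_token` and `unk_token` resolve to the same literal string, `<unk>`.

```python
import json

# The special-token map as added in this commit, embedded as a string so
# the sketch runs without the repository checked out.
SPECIAL_TOKENS_MAP = json.loads("""
{
  "bos_token": {"content": "<s>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false},
  "eos_token": {"content": "</s>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false},
  "pad_token": {"content": "<unk>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false},
  "unk_token": {"content": "<unk>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false}
}
""")


def token_strings(token_map):
    """Map each special-token role to its literal token string (hypothetical helper)."""
    return {role: spec["content"] for role, spec in token_map.items()}


def shared_tokens(token_map):
    """Return token strings used by more than one role (hypothetical helper)."""
    groups = {}
    for role, content in token_strings(token_map).items():
        groups.setdefault(content, []).append(role)
    return {content: roles for content, roles in groups.items() if len(roles) > 1}


print(token_strings(SPECIAL_TOKENS_MAP))
print(shared_tokens(SPECIAL_TOKENS_MAP))
```

A check like this may be useful before batching inputs, since padding positions and unknown tokens would share a token string under this map.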
tokenization_telechat.py ADDED
@@ -0,0 +1,220 @@
+ import os
+ from shutil import copyfile
+ from typing import Any, Dict, List, Optional, Tuple
+ import sentencepiece as spm
+ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
+ from transformers.utils import logging
+
+ logger = logging.get_logger(__name__)
+
+ VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
+
+ # TODO: when we get download url from huggingface, refresh the map
+ PRETRAINED_VOCAB_FILES_MAP = {
+     "vocab_file": {},
+     "tokenizer_file": {},
+ }
+
+
+ class TelechatTokenizer(PreTrainedTokenizer):
+
+     vocab_files_names = VOCAB_FILES_NAMES
+     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+     model_input_names = ["input_ids", "attention_mask"]
+
+     def __init__(
+         self,
+         vocab_file,
+         unk_token="<unk>",
+         bos_token="<_start>",
+         eos_token="<_end>",
+         pad_token="<_pad>",
+         sp_model_kwargs: Optional[Dict[str, Any]] = None,
+         add_bos_token=True,
+         add_eos_token=False,
+         clean_up_tokenization_spaces=False,
+         **kwargs,
+     ):
+         self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+         bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
+         eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
+         unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
+         pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
+         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+         self.sp_model.Load(vocab_file)
+         super().__init__(
+             bos_token=bos_token,
+             eos_token=eos_token,
+             unk_token=unk_token,
+             pad_token=pad_token,
+             add_bos_token=add_bos_token,
+             add_eos_token=add_eos_token,
+             sp_model_kwargs=self.sp_model_kwargs,
+             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+             **kwargs,
+         )
+         self.vocab_file = vocab_file
+         self.add_bos_token = add_bos_token
+         self.add_eos_token = add_eos_token
+
+
+     def __getstate__(self):
+         state = self.__dict__.copy()
+         state["sp_model"] = None
+         return state
+
+     def __setstate__(self, d):
+         self.__dict__ = d
+         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+         self.sp_model.Load(self.vocab_file)
+
+     @property
+     def vocab_size(self):
+         """Returns vocab size"""
+         return self.sp_model.get_piece_size()
+
+     def get_vocab(self):
+         """Returns vocab as a dict"""
+         vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+         vocab.update(self.added_tokens_encoder)
+         return vocab
+
+     def _tokenize(self, text):
+         """Returns a tokenized string."""
+         return self.sp_model.encode(text, out_type=str)
+
+     def _convert_token_to_id(self, token):
+         """Converts a token (str) in an id using the vocab."""
+         return self.sp_model.piece_to_id(token)
+
+     def _convert_id_to_token(self, index):
+         """Converts an index (integer) in a token (str) using the vocab."""
+         token = self.sp_model.IdToPiece(index)
+         return token
+
+     def convert_tokens_to_string(self, tokens):
+         """Converts a sequence of tokens (string) in a single string."""
+         current_sub_tokens = []
+         out_string = ""
+         prev_is_special = False
+         for i, token in enumerate(tokens):
+             # make sure that special tokens are not decoded using sentencepiece model
+             if token in self.all_special_tokens:
+                 if not prev_is_special and i != 0:
+                     out_string += " "
+                 out_string += self.sp_model.decode(current_sub_tokens) + token
+                 prev_is_special = True
+                 current_sub_tokens = []
+             else:
+                 current_sub_tokens.append(token)
+                 prev_is_special = False
+         out_string += self.sp_model.decode(current_sub_tokens)
+         return out_string
+
+     def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
+         """
+         Save the vocabulary and special tokens file to a directory.
+
+         Args:
+             save_directory (`str`):
+                 The directory in which to save the vocabulary.
+
+         Returns:
+             `Tuple(str)`: Paths to the files saved.
+         """
+         if not os.path.isdir(save_directory):
+             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+             return
+         out_vocab_file = os.path.join(
+             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+         )
+
+         if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+             copyfile(self.vocab_file, out_vocab_file)
+         elif not os.path.isfile(self.vocab_file):
+             with open(out_vocab_file, "wb") as fi:
+                 content_spiece_model = self.sp_model.serialized_model_proto()
+                 fi.write(content_spiece_model)
+
+         return (out_vocab_file,)
+
+     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+         bos_token_id = [self.bos_token_id] if self.add_bos_token else []
+         eos_token_id = [self.eos_token_id] if self.add_eos_token else []
+
+         output = bos_token_id + token_ids_0 + eos_token_id
+
+         if token_ids_1 is not None:
+             output = output + bos_token_id + token_ids_1 + eos_token_id
+
+         return output
+
+     def get_special_tokens_mask(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+     ) -> List[int]:
+         """
+         Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+         special tokens using the tokenizer `prepare_for_model` method.
+
+         Args:
+             token_ids_0 (`List[int]`):
+                 List of IDs.
+             token_ids_1 (`List[int]`, *optional*):
+                 Optional second list of IDs for sequence pairs.
+             already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                 Whether or not the token list is already formatted with special tokens for the model.
+
+         Returns:
+             `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+         """
+         if already_has_special_tokens:
+             return super().get_special_tokens_mask(
+                 token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+             )
+
+         bos_token_id = [1] if self.add_bos_token else []
+         eos_token_id = [1] if self.add_eos_token else []
+
+         if token_ids_1 is None:
+             return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
+         return (
+             bos_token_id
+             + ([0] * len(token_ids_0))
+             + eos_token_id
+             + bos_token_id
+             + ([0] * len(token_ids_1))
+             + eos_token_id
+         )
+
189
+ def create_token_type_ids_from_sequences(
190
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
191
+ ) -> List[int]:
192
+ """
193
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task. A
194
+ sequence pair mask has the following format:
195
+
196
+ ```
197
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
198
+ | first sequence | second sequence |
199
+ ```
200
+
201
+ If `token_ids_1` is None, only the first portion of the mask (0s) is returned.
202
+
203
+ Args:
204
+ token_ids_0 (`List[int]`):
205
+ List of ids.
206
+ token_ids_1 (`List[int]`, *optional*):
207
+ Optional second list of IDs for sequence pairs.
208
+
209
+ Returns:
210
+ `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
211
+ """
212
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
213
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
214
+
215
+ output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
216
+
217
+ if token_ids_1 is not None:
218
+ output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
219
+
220
+ return output
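The three special-token helpers in the file above follow one pattern: each sequence is optionally wrapped in a BOS and an EOS token, and the mask and token-type lists mirror that layout. The following is a minimal standalone sketch of that logic, not the tokenizer class itself. The function names and the IDs `1` and `2` are hypothetical placeholders chosen for illustration; the real IDs come from the SentencePiece model. The defaults of `False` match the `add_bos_token` and `add_eos_token` values in `tokenizer_config.json` in this commit.

```python
from typing import List, Optional

# Hypothetical IDs for illustration only; the real values come from the model.
BOS_ID = 1
EOS_ID = 2


def build_inputs(ids_0: List[int], ids_1: Optional[List[int]] = None,
                 add_bos: bool = False, add_eos: bool = False) -> List[int]:
    """Wrap each sequence in optional BOS/EOS tokens and concatenate."""
    bos = [BOS_ID] if add_bos else []
    eos = [EOS_ID] if add_eos else []
    output = bos + ids_0 + eos
    if ids_1 is not None:
        output = output + bos + ids_1 + eos
    return output


def special_tokens_mask(ids_0: List[int], ids_1: Optional[List[int]] = None,
                        add_bos: bool = False, add_eos: bool = False) -> List[int]:
    """Return 1 at positions holding a special token, 0 elsewhere."""
    bos = [1] if add_bos else []
    eos = [1] if add_eos else []
    mask = bos + [0] * len(ids_0) + eos
    if ids_1 is not None:
        mask = mask + bos + [0] * len(ids_1) + eos
    return mask


def token_type_ids(ids_0: List[int], ids_1: Optional[List[int]] = None,
                   add_bos: bool = False, add_eos: bool = False) -> List[int]:
    """Return 0 for every position of the first sequence, 1 for the second."""
    extra = int(add_bos) + int(add_eos)  # special tokens added per sequence
    types = [0] * (len(ids_0) + extra)
    if ids_1 is not None:
        types += [1] * (len(ids_1) + extra)
    return types


pair_ids = build_inputs([10, 11], [20], add_bos=True, add_eos=True)
pair_mask = special_tokens_mask([10, 11], [20], add_bos=True, add_eos=True)
pair_types = token_type_ids([10, 11], [20], add_bos=True, add_eos=True)
plain_ids = build_inputs([10, 11], [20])
```

All three lists have the same length for the same arguments, which is what lets a caller index them in parallel. With both flags off, as in the shipped config, the pair is simply concatenated with no separator between the two sequences.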
tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2c86d881f9a94b1c50bf25f8f987accea9ec2a1be74529f0240d8e13e66aa3d
3
+ size 1978781
tokenizer_config.json ADDED
@@ -0,0 +1,40 @@
1
+ {
2
+ "name_or_path": "ChinaTelecom/telechat-12b",
3
+ "tokenizer_class": "TelechatTokenizer",
4
+ "auto_map": {
5
+ "AutoTokenizer": [
6
+ "tokenization_telechat.TelechatTokenizer",
7
+ null
8
+ ]
9
+ },
10
+ "add_bos_token": false,
11
+ "add_eos_token": false,
12
+ "use_fast": false,
13
+ "clean_up_tokenization_spaces": false,
14
+ "eos_token": {
15
+ "__type": "AddedToken",
16
+ "content": "<_end>",
17
+ "lstrip": false,
18
+ "normalized": true,
19
+ "rstrip": false,
20
+ "single_word": true
21
+ },
22
+ "model_max_length": 100000000,
23
+ "sp_model_kwargs": {},
24
+ "pad_token": {
25
+ "__type": "AddedToken",
26
+ "content": "<_pad>",
27
+ "lstrip": false,
28
+ "normalized": true,
29
+ "rstrip": false,
30
+ "single_word": true
31
+ },
32
+ "unk_token": {
33
+ "__type": "AddedToken",
34
+ "content": "<_end>",
35
+ "lstrip": false,
36
+ "normalized": true,
37
+ "rstrip": false,
38
+ "single_word": true
39
+ }
40
+ }
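One detail of the config above is easy to miss: `eos_token` and `unk_token` carry the same content string, and both BOS and EOS insertion are disabled. The sketch below checks this with the standard `json` module on a trimmed copy of the config. The helper name `token_contents` is hypothetical, and the embedded text keeps only the fields needed for the check.

```python
import json

# Trimmed copy of tokenizer_config.json; only the fields used below are kept.
CONFIG_TEXT = """
{
  "add_bos_token": false,
  "add_eos_token": false,
  "eos_token": {"__type": "AddedToken", "content": "<_end>"},
  "pad_token": {"__type": "AddedToken", "content": "<_pad>"},
  "unk_token": {"__type": "AddedToken", "content": "<_end>"}
}
"""


def token_contents(config: dict) -> dict:
    """Map each AddedToken entry to its literal content string."""
    return {
        key: value["content"]
        for key, value in config.items()
        if isinstance(value, dict) and value.get("__type") == "AddedToken"
    }


config = json.loads(CONFIG_TEXT)
contents = token_contents(config)
# True when the end-of-sequence and unknown tokens are the same string.
eos_equals_unk = contents["eos_token"] == contents["unk_token"]
```

Because the two tokens share one string, code that treats an unknown token differently from an end-of-sequence token cannot tell them apart by content alone. Whether that is intentional is not stated in the commit.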