Spaces:
Paused
Paused
Commit
·
4721aa1
1
Parent(s):
0801e4f
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +3 -0
- .github/ISSUE_TEMPLATE/bug_report.yaml +77 -0
- .github/ISSUE_TEMPLATE/feature-request.yaml +34 -0
- .github/PULL_REQUEST_TEMPLATE/pr_template.md +34 -0
- .gitignore +10 -0
- .ipynb_checkpoints/Untitled-checkpoint.ipynb +6 -0
- DEPLOYMENT.md +39 -0
- DEPLOYMENT_en.md +43 -0
- MODEL_LICENSE +65 -0
- PROMPT.md +203 -0
- PROMPT_en.md +201 -0
- README_en.md +250 -0
- Untitled.ipynb +192 -0
- basic_demo/cli_demo.py +79 -0
- basic_demo/cli_demo_bad_word_ids.py +83 -0
- basic_demo/utils.py +58 -0
- basic_demo/web_demo.py +117 -0
- basic_demo/web_demo2.py +89 -0
- chatglm3-6b/.gitattributes +35 -0
- chatglm3-6b/MODEL_LICENSE +65 -0
- chatglm3-6b/README.md +102 -0
- chatglm3-6b/config.json +42 -0
- chatglm3-6b/configuration_chatglm.py +61 -0
- chatglm3-6b/modeling_chatglm.py +1293 -0
- chatglm3-6b/pytorch_model-00001-of-00007.bin +3 -0
- chatglm3-6b/pytorch_model-00002-of-00007.bin +3 -0
- chatglm3-6b/pytorch_model-00003-of-00007.bin +3 -0
- chatglm3-6b/pytorch_model-00004-of-00007.bin +3 -0
- chatglm3-6b/pytorch_model-00005-of-00007.bin +3 -0
- chatglm3-6b/pytorch_model-00006-of-00007.bin +3 -0
- chatglm3-6b/pytorch_model-00007-of-00007.bin +3 -0
- chatglm3-6b/pytorch_model.bin.index.json +207 -0
- chatglm3-6b/quantization.py +188 -0
- chatglm3-6b/tokenization_chatglm.py +300 -0
- chatglm3-6b/tokenizer.model +3 -0
- chatglm3-6b/tokenizer_config.json +12 -0
- composite_demo/.streamlit/config.toml +2 -0
- composite_demo/README.md +85 -0
- composite_demo/README_en.md +85 -0
- composite_demo/assets/demo.png +0 -0
- composite_demo/assets/emojis.png +0 -0
- composite_demo/assets/heart.png +0 -0
- composite_demo/assets/tool.png +0 -0
- composite_demo/client.py +197 -0
- composite_demo/conversation.py +119 -0
- composite_demo/demo_chat.py +80 -0
- composite_demo/demo_ci.py +329 -0
- composite_demo/demo_tool.py +192 -0
- composite_demo/main.py +63 -0
- composite_demo/requirements.txt +12 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
resources/code_en.gif filter=lfs diff=lfs merge=lfs -text
|
37 |
+
resources/web-demo.gif filter=lfs diff=lfs merge=lfs -text
|
38 |
+
resources/web-demo2.gif filter=lfs diff=lfs merge=lfs -text
|
.github/ISSUE_TEMPLATE/bug_report.yaml
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: "\U0001F41B Bug Report"
|
2 |
+
description: Submit a bug report to help us improve ChatGLM3 / 提交一个 Bug 问题报告来帮助我们改进 ChatGLM3
|
3 |
+
body:
|
4 |
+
- type: textarea
|
5 |
+
id: system-info
|
6 |
+
attributes:
|
7 |
+
label: System Info / 系統信息
|
8 |
+
description: Your operating environment / 您的运行环境信息
|
9 |
+
placeholder: Includes Cuda version, Transformers version, Python version, operating system, hardware information (if you suspect a hardware problem)... / 包括Cuda版本,Transformers版本,Python版本,操作系统,硬件信息(如果您怀疑是硬件方面的问题)...
|
10 |
+
validations:
|
11 |
+
required: true
|
12 |
+
|
13 |
+
- type: textarea
|
14 |
+
id: who-can-help
|
15 |
+
attributes:
|
16 |
+
label: Who can help? / 谁可以帮助到您?
|
17 |
+
description: |
|
18 |
+
Your issue will be replied to more quickly if you can figure out the right person to tag with @
|
19 |
+
All issues are read by one of the maintainers, so if you don't know who to tag, just leave this blank and our maintainer will ping the right person.
|
20 |
+
|
21 |
+
Please tag fewer than 3 people.
|
22 |
+
|
23 |
+
如果您能找到合适的标签 @,您的问题会更快得到回复。
|
24 |
+
所有问题都会由我们的维护者阅读,如果您不知道该标记谁,只需留空,我们的维护人员会找到合适的开发组成员来解决问题。
|
25 |
+
|
26 |
+
标记的人数应该不超过 3 个人。
|
27 |
+
|
28 |
+
Related demo leader / 相关demo负责人 :
|
29 |
+
- finetune_demo: @Btlmd
|
30 |
+
- langchain_demo: @yincf
|
31 |
+
- composite_demo: @abmfy
|
32 |
+
|
33 |
+
If it's not a bug in these three subsections, you may not specify the helper. Our maintainer will find the right person in the development group to solve the problem.
|
34 |
+
|
35 |
+
如果不是这三个子版块的bug,您可以不指明帮助者,我们的维护人员会找到合适的开发组成员来解决问题。
|
36 |
+
|
37 |
+
placeholder: "@Username ..."
|
38 |
+
|
39 |
+
- type: checkboxes
|
40 |
+
id: information-scripts-examples
|
41 |
+
attributes:
|
42 |
+
label: Information / 问题信息
|
43 |
+
description: 'The problem arises when using: / 问题出现在'
|
44 |
+
options:
|
45 |
+
- label: "The official example scripts / 官方的示例脚本"
|
46 |
+
- label: "My own modified scripts / 我自己修改的脚本和任务"
|
47 |
+
|
48 |
+
- type: textarea
|
49 |
+
id: reproduction
|
50 |
+
validations:
|
51 |
+
required: true
|
52 |
+
attributes:
|
53 |
+
label: Reproduction / 复现过程
|
54 |
+
description: |
|
55 |
+
Please provide a code example that reproduces the problem you encountered, preferably with a minimal reproduction unit.
|
56 |
+
If you have code snippets, error messages, stack traces, please provide them here as well.
|
57 |
+
Please format your code correctly using code tags. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
|
58 |
+
Do not use screenshots, as they are difficult to read and (more importantly) do not allow others to copy and paste your code.
|
59 |
+
|
60 |
+
请提供能重现您遇到的问题的代码示例,最好是最小复现单元。
|
61 |
+
如果您有代码片段、错误信息、堆栈跟踪,也请在此提供。
|
62 |
+
请使用代码标签正确格式化您的代码。请参见 https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
|
63 |
+
请勿使用截图,因为截图难以阅读,而且(更重要的是)不允许他人复制粘贴您的代码。
|
64 |
+
placeholder: |
|
65 |
+
Steps to reproduce the behavior/复现Bug的步骤:
|
66 |
+
|
67 |
+
1.
|
68 |
+
2.
|
69 |
+
3.
|
70 |
+
|
71 |
+
- type: textarea
|
72 |
+
id: expected-behavior
|
73 |
+
validations:
|
74 |
+
required: true
|
75 |
+
attributes:
|
76 |
+
label: Expected behavior / 期待表现
|
77 |
+
description: "A clear and concise description of what you would expect to happen. /简单描述您期望发生的事情。"
|
.github/ISSUE_TEMPLATE/feature-request.yaml
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: "\U0001F680 Feature request"
|
2 |
+
description: Submit a request for a new ChatGLM3 feature / 提交一个新的 ChatGLM3 的功能建议
|
3 |
+
labels: [ "feature" ]
|
4 |
+
body:
|
5 |
+
- type: textarea
|
6 |
+
id: feature-request
|
7 |
+
validations:
|
8 |
+
required: true
|
9 |
+
attributes:
|
10 |
+
label: Feature request / 功能建议
|
11 |
+
description: |
|
12 |
+
A brief description of the functional proposal. Links to corresponding papers and code are desirable.
|
13 |
+
对功能建议的简述。最好提供对应的论文和代码链接
|
14 |
+
|
15 |
+
- type: textarea
|
16 |
+
id: motivation
|
17 |
+
validations:
|
18 |
+
required: true
|
19 |
+
attributes:
|
20 |
+
label: Motivation / 动机
|
21 |
+
description: |
|
22 |
+
Your motivation for making the suggestion. If that motivation is related to another GitHub issue, link to it here.
|
23 |
+
您提出建议的动机。如果该动机与另一个 GitHub 问题有关,请在此处提供对应的链接。
|
24 |
+
|
25 |
+
- type: textarea
|
26 |
+
id: contribution
|
27 |
+
validations:
|
28 |
+
required: true
|
29 |
+
attributes:
|
30 |
+
label: Your contribution / 您的贡献
|
31 |
+
description: |
|
32 |
+
|
33 |
+
Your PR link or any other link you can help with.
|
34 |
+
您的PR链接或者其他您能提供帮助的链接。
|
.github/PULL_REQUEST_TEMPLATE/pr_template.md
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Raise valuable PR / 提出有价值的PR
|
2 |
+
|
3 |
+
## Caution/ 注意事项:
|
4 |
+
Users should keep the following points in mind when submitting PRs:
|
5 |
+
|
6 |
+
1. The proposed PR should be about this project.
|
7 |
+
2. the proposed PR should be relevant, if there are multiple ideas and optimizations, they should be assigned to different PRs.
|
8 |
+
|
9 |
+
用户在提交PR时候应该注意以下几点:
|
10 |
+
|
11 |
+
1. 提出的PR应该是关于本项目的。
|
12 |
+
2. 提出的PR应该具有针对性,如果具有多个不同的想法和优化方案,应该分配到不同的PR中。
|
13 |
+
|
14 |
+
## 不应该提出的PR / PRs that should not be proposed
|
15 |
+
|
16 |
+
If a developer proposes a PR about any of the following, it may be closed or Rejected.
|
17 |
+
|
18 |
+
1. those that don't describe improvement options.
|
19 |
+
2. multiple issues of different types combined in one PR.
|
20 |
+
3. The proposed PR is highly duplicative of already existing PRs.
|
21 |
+
|
22 |
+
如果开发者提出关于以下方面的PR,则可能会被直接关闭或拒绝通过。
|
23 |
+
|
24 |
+
1. 没有说明改进方案的。
|
25 |
+
2. 多个不同类型的问题合并在一个PR中的。
|
26 |
+
3. 提出的PR与已经存在的PR高度重复的。
|
27 |
+
|
28 |
+
|
29 |
+
# 检查您的PR
|
30 |
+
- [ ] Have you read the Contributor Guidelines, Pull Request section? / 您是否阅读了贡献者指南、Pull Request 部分?
|
31 |
+
- [ ] Has this been discussed/approved via a Github issue or forum? If so, add a link. / 是否通过 Github 问题或论坛讨论/批准过?如果是,请添加链接。
|
32 |
+
- [ ] Did you make sure you updated the documentation with your changes? Here are the Documentation Guidelines, and here are the Documentation Formatting Tips. /您是否确保根据您的更改更新了文档?这里是文档指南,这里是文档格式化技巧。
|
33 |
+
- [ ] Did you write new required tests? / 您是否编写了新的必要测试?
|
34 |
+
- [ ] Are your PRs for only one issue / 您的PR是否仅针对一个问题
|
.gitignore
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__
|
2 |
+
|
3 |
+
# finetune_demo: generated & downloaded files
|
4 |
+
finetune_demo/output
|
5 |
+
finetune_demo/data
|
6 |
+
finetune_demo/formatted_data
|
7 |
+
ToolAlpaca/
|
8 |
+
AdvertiseGen/
|
9 |
+
*.gz
|
10 |
+
*.idea
|
.ipynb_checkpoints/Untitled-checkpoint.ipynb
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [],
|
3 |
+
"metadata": {},
|
4 |
+
"nbformat": 4,
|
5 |
+
"nbformat_minor": 5
|
6 |
+
}
|
DEPLOYMENT.md
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## 低成本部署
|
2 |
+
|
3 |
+
### 模型量化
|
4 |
+
|
5 |
+
默认情况下,模型以 FP16 精度加载,运行上述代码需要大概 13GB 显存。如果你的 GPU 显存有限,可以尝试以量化方式加载模型,使用方法如下:
|
6 |
+
|
7 |
+
```python
|
8 |
+
model = AutoModel.from_pretrained("THUDM/chatglm3-6b",trust_remote_code=True).quantize(4).cuda()
|
9 |
+
```
|
10 |
+
|
11 |
+
模型量化会带来一定的性能损失,经过测试,ChatGLM3-6B 在 4-bit 量化下仍然能够进行自然流畅的生成。
|
12 |
+
|
13 |
+
### CPU 部署
|
14 |
+
|
15 |
+
如果你没有 GPU 硬件的话,也可以在 CPU 上进行推理,但是推理速度会更慢。使用方法如下(需要大概 32GB 内存)
|
16 |
+
```python
|
17 |
+
model = AutoModel.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True).float()
|
18 |
+
```
|
19 |
+
|
20 |
+
### Mac 部署
|
21 |
+
|
22 |
+
对于搭载了 Apple Silicon 或者 AMD GPU 的 Mac,可以使用 MPS 后端来在 GPU 上运行 ChatGLM3-6B。需要参考 Apple 的 [官方说明](https://developer.apple.com/metal/pytorch) 安装 PyTorch-Nightly(正确的版本号应该是2.x.x.dev2023xxxx,而不是 2.x.x)。
|
23 |
+
|
24 |
+
目前在 MacOS 上只支持[从本地加载模型](README.md#从本地加载模型)。将代码中的模型加载改为从本地加载,并使用 mps 后端:
|
25 |
+
```python
|
26 |
+
model = AutoModel.from_pretrained("your local path", trust_remote_code=True).to('mps')
|
27 |
+
```
|
28 |
+
|
29 |
+
加载半精度的 ChatGLM3-6B 模型需要大概 13GB 内存。内存较小的机器(比如 16GB 内存的 MacBook Pro),在空余内存不足的情况下会使用硬盘上的虚拟内存,导致推理速度严重变慢。
|
30 |
+
|
31 |
+
### 多卡部署
|
32 |
+
如果你有多张 GPU,但是每张 GPU 的显存大小都不足以容纳完整的模型,那么可以将模型切分在多张GPU上。首先安装 accelerate: `pip install accelerate`,然后通过如下方法加载模型:
|
33 |
+
|
34 |
+
```python
|
35 |
+
from utils import load_model_on_gpus
|
36 |
+
|
37 |
+
model = load_model_on_gpus("THUDM/chatglm3-6b", num_gpus=2)
|
38 |
+
```
|
39 |
+
即可将模型部署到两张 GPU 上进行推理。你可以将 `num_gpus` 改为你希望使用的 GPU 数。默认是均匀切分的,你也可以传入 `device_map` 参数来自己指定。
|
DEPLOYMENT_en.md
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Low-Cost Deployment
|
2 |
+
|
3 |
+
### Model Quantization
|
4 |
+
|
5 |
+
By default, the model is loaded with FP16 precision, running the above code requires about 13GB of VRAM. If your GPU's VRAM is limited, you can try loading the model quantitatively, as follows:
|
6 |
+
|
7 |
+
```python
|
8 |
+
model = AutoModel.from_pretrained("THUDM/chatglm3-6b",trust_remote_code=True).quantize(4).cuda()
|
9 |
+
```
|
10 |
+
|
11 |
+
Model quantization will bring some performance loss. Through testing, ChatGLM3-6B can still perform natural and smooth generation under 4-bit quantization.
|
12 |
+
|
13 |
+
### CPU Deployment
|
14 |
+
|
15 |
+
If you don't have GPU hardware, you can also run inference on the CPU, but the inference speed will be slower. The usage is as follows (requires about 32GB of memory):
|
16 |
+
|
17 |
+
```python
|
18 |
+
model = AutoModel.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True).float()
|
19 |
+
```
|
20 |
+
|
21 |
+
### Mac Deployment
|
22 |
+
|
23 |
+
For Macs equipped with Apple Silicon or AMD GPUs, the MPS backend can be used to run ChatGLM3-6B on the GPU. Refer to Apple's [official instructions](https://developer.apple.com/metal/pytorch) to install PyTorch-Nightly (the correct version number should be 2.x.x.dev2023xxxx, not 2.x.x).
|
24 |
+
|
25 |
+
Currently, only [loading the model locally](README_en.md#load-model-locally) is supported on MacOS. Change the model loading in the code to load locally and use the MPS backend:
|
26 |
+
|
27 |
+
```python
|
28 |
+
model = AutoModel.from_pretrained("your local path", trust_remote_code=True).to('mps')
|
29 |
+
```
|
30 |
+
|
31 |
+
Loading the half-precision ChatGLM3-6B model requires about 13GB of memory. Machines with smaller memory (such as a 16GB memory MacBook Pro) will use virtual memory on the hard disk when there is insufficient free memory, resulting in a significant slowdown in inference speed.
|
32 |
+
|
33 |
+
### Multi-GPU Deployment
|
34 |
+
|
35 |
+
If you have multiple GPUs, but each GPU's VRAM size is not enough to accommodate the complete model, then the model can be split across multiple GPUs. First, install accelerate: `pip install accelerate`, and then load the model through the following methods:
|
36 |
+
|
37 |
+
```python
|
38 |
+
from utils import load_model_on_gpus
|
39 |
+
|
40 |
+
model = load_model_on_gpus("THUDM/chatglm3-6b", num_gpus=2)
|
41 |
+
```
|
42 |
+
|
43 |
+
This allows the model to be deployed on two GPUs for inference. You can change `num_gpus` to the number of GPUs you want to use. It is evenly split by default, but you can also pass the `device_map` parameter to specify it yourself.
|
MODEL_LICENSE
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
The ChatGLM3-6B License
|
2 |
+
|
3 |
+
1. 定义
|
4 |
+
|
5 |
+
“许可方”是指分发其软件的 ChatGLM3-6B 模型团队。
|
6 |
+
|
7 |
+
“软件”是指根据本许可提供的 ChatGLM3-6B 模型参数。
|
8 |
+
|
9 |
+
2. 许可授予
|
10 |
+
|
11 |
+
根据本许可的条款和条件,许可方特此授予您非排他性、全球性、不可转让、不可再许可、可撤销、免版税的版权许可。
|
12 |
+
|
13 |
+
上述版权声明和本许可声明应包含在本软件的所有副本或重要部分中。
|
14 |
+
|
15 |
+
3.限制
|
16 |
+
|
17 |
+
您不得出于任何军事或非法目的使用、复制、修改、合并、发布、分发、复制或创建本软件的全部或部分衍生作品。
|
18 |
+
|
19 |
+
您不得利用本软件从事任何危害国家安全和国家统一、危害社会公共利益、侵犯人身权益的行为。
|
20 |
+
|
21 |
+
4.免责声明
|
22 |
+
|
23 |
+
本软件“按原样”提供,不提供任何明示或暗示的保证,包括但不限于对适销性、特定用途的适用性和非侵权性的保证。 在任何情况下,作者或版权持有人均不对任何索赔、损害或其他责任负责,无论是在合同诉讼、侵权行为还是其他方面,由软件或软件的使用或其他交易引起、由软件引起或与之相关 软件。
|
24 |
+
|
25 |
+
5. 责任限制
|
26 |
+
|
27 |
+
除适用法律禁止的范围外,在任何情况下且根据任何法律理论,无论是基于侵权行为、疏忽、合同、责任或其他原因,任何许可方均不对您承担任何直接、间接、特殊、偶然、示范性、 或间接损害,或任何其他商业损失,即使许可人已被告知此类损害的可能性。
|
28 |
+
|
29 |
+
6.争议解决
|
30 |
+
|
31 |
+
本许可受中华人民共和国法律管辖并按其解释。 因本许可引起的或与本许可有关的任何争议应提交北京市海淀区人民法院。
|
32 |
+
|
33 |
+
请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 [email protected] 与我们联系。
|
34 |
+
|
35 |
+
1. Definitions
|
36 |
+
|
37 |
+
“Licensor” means the ChatGLM3-6B Model Team that distributes its Software.
|
38 |
+
|
39 |
+
“Software” means the ChatGLM3-6B model parameters made available under this license.
|
40 |
+
|
41 |
+
2. License Grant
|
42 |
+
|
43 |
+
Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software.
|
44 |
+
|
45 |
+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
46 |
+
|
47 |
+
3. Restriction
|
48 |
+
|
49 |
+
You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any military, or illegal purposes.
|
50 |
+
|
51 |
+
You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings.
|
52 |
+
|
53 |
+
4. Disclaimer
|
54 |
+
|
55 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
56 |
+
|
57 |
+
5. Limitation of Liability
|
58 |
+
|
59 |
+
EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
|
60 |
+
|
61 |
+
6. Dispute Resolution
|
62 |
+
|
63 |
+
This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing.
|
64 |
+
|
65 |
+
Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at [email protected].
|
PROMPT.md
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## ChatGLM3 对话格式
|
2 |
+
为了避免用户输入的注入攻击,以及统一 Code Interpreter,Tool & Agent 等任务的输入,ChatGLM3 采用了全新的对话格式。
|
3 |
+
|
4 |
+
### 规定
|
5 |
+
#### 整体结构
|
6 |
+
ChatGLM3 对话的格式由若干对话组成,其中每个对话包含对话头和内容,一个典型的多轮对话结构如下
|
7 |
+
```text
|
8 |
+
<|system|>
|
9 |
+
You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.
|
10 |
+
<|user|>
|
11 |
+
Hello
|
12 |
+
<|assistant|>
|
13 |
+
Hello, I'm ChatGLM3. What can I assist you today?
|
14 |
+
```
|
15 |
+
**实际中每轮对话内容并不一定以换行符结尾,这里只是为了美观,下同**
|
16 |
+
|
17 |
+
#### 对话头
|
18 |
+
对话头占完整的一行,格式为
|
19 |
+
```text
|
20 |
+
<|role|>{metadata}
|
21 |
+
```
|
22 |
+
其中 `<|role|>` 部分使用 special token 表示,无法从文本形式被 tokenizer 编码以防止注入。metadata 部分采用纯文本表示,为可选内容。
|
23 |
+
* `<|system|>`:系统信息,设计上可穿插于对话中,**但目前规定仅可以出现在开头**
|
24 |
+
* `<|user|>`:用户
|
25 |
+
- 不会连续出现多个来自 `<|user|>` 的信息
|
26 |
+
* `<|assistant|>`:AI 助手
|
27 |
+
- 在出现之前必须有一个来自 `<|user|>` 的信息
|
28 |
+
* `<|observation|>`:外部的返回结果
|
29 |
+
- 必须在 `<|assistant|>` 的信息之后
|
30 |
+
|
31 |
+
### 样例场景
|
32 |
+
|
33 |
+
为提升可读性,下列样例场景中表示角色的 special token 前均额外添加了一个换行符。实际使用及 tokenizer 实现中均无需额外添加这一换行。
|
34 |
+
|
35 |
+
#### 多轮对话
|
36 |
+
* 有且仅有 `<|user|>`、`<|assistant|>`、`<|system|>` 三种 role
|
37 |
+
```text
|
38 |
+
<|system|>
|
39 |
+
You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.
|
40 |
+
<|user|>
|
41 |
+
Hello
|
42 |
+
<|assistant|>
|
43 |
+
Hello, I'm ChatGLM3. What can I assist you today?
|
44 |
+
```
|
45 |
+
|
46 |
+
#### 工具调用
|
47 |
+
````
|
48 |
+
<|system|>
|
49 |
+
Answer the following questions as best as you can. You have access to the following tools:
|
50 |
+
[
|
51 |
+
{
|
52 |
+
"name": "get_current_weather",
|
53 |
+
"description": "Get the current weather in a given location",
|
54 |
+
"parameters": {
|
55 |
+
"type": "object",
|
56 |
+
"properties": {
|
57 |
+
"location": {
|
58 |
+
"type": "string",
|
59 |
+
"description": "The city and state, e.g. San Francisco, CA",
|
60 |
+
},
|
61 |
+
"unit": {"type": "string"},
|
62 |
+
},
|
63 |
+
"required": ["location"],
|
64 |
+
},
|
65 |
+
}
|
66 |
+
]
|
67 |
+
<|user|>
|
68 |
+
今天北京的天气怎么样?
|
69 |
+
<|assistant|>
|
70 |
+
好的,让我们来查看今天的天气
|
71 |
+
<|assistant|>get_current_weather
|
72 |
+
```python
|
73 |
+
tool_call(location="beijing", unit="celsius")
|
74 |
+
```
|
75 |
+
<|observation|>
|
76 |
+
{"temperature": 22}
|
77 |
+
<|assistant|>
|
78 |
+
根据查询结果,今天北京的气温为 22 摄氏度。
|
79 |
+
````
|
80 |
+
|
81 |
+
#### 代码执行
|
82 |
+
* 有 `<|user|>`、`<|assistant|>`、`<|system|>`、`<|observation|>` 四种 role。其中 `<|assistant|>` 的 metadata 只有 interpreter。
|
83 |
+
|
84 |
+
`````text
|
85 |
+
<|system|>
|
86 |
+
你是一位智能AI助手,你叫ChatGLM3,你连接着一台电脑,但请注意不能联网。在使用Python解决任务时,你可以运行代码并得到结果,如果运行结果有错误,你需要尽可能对代码进行改进。你可以处理用户上传到电脑上的文件,文件默认存储路径是/mnt/data/。
|
87 |
+
<|user|>
|
88 |
+
#File: /mnt/data/metadata.jsonl
|
89 |
+
#Size: 35380
|
90 |
+
#File uploaded
|
91 |
+
文件中是否存在缺失值或异常值?
|
92 |
+
<|assistant|>
|
93 |
+
首先,我将读取您提供的文件,并查看其内容以确定是否存在缺失值或异常值。
|
94 |
+
|
95 |
+
我们从加载文件开始:
|
96 |
+
<|assistant|>interpreter
|
97 |
+
```python
|
98 |
+
import json
|
99 |
+
|
100 |
+
# Load the uploaded file
|
101 |
+
data = []
|
102 |
+
with open("/mnt/data/metadata.jsonl", 'r') as file:
|
103 |
+
for line in file:
|
104 |
+
data.append(json.loads(line))
|
105 |
+
|
106 |
+
# Display the first few records to get a sense of the data
|
107 |
+
data[:5]
|
108 |
+
```
|
109 |
+
<|observation|>
|
110 |
+
```result
|
111 |
+
[{'file_name': 'aceinthehole.png',
|
112 |
+
'name': 'Ace in the Hole',
|
113 |
+
'type': 'survivor',
|
114 |
+
'description': 'Lady Luck always seems to be throwing something good your way.'},
|
115 |
+
{'file_name': 'adrenaline.png',
|
116 |
+
'name': 'Adrenaline',
|
117 |
+
'type': 'survivor',
|
118 |
+
'description': 'You are fuelled by unexpected energy when on the verge of escape.'},
|
119 |
+
{'file_name': 'aftercare.png',
|
120 |
+
'name': 'Aftercare',
|
121 |
+
'type': 'survivor',
|
122 |
+
'description': "Unlocks potential in one's aura reading ability."},
|
123 |
+
{'file_name': 'agitation.png',
|
124 |
+
'name': 'Agitation',
|
125 |
+
'type': 'killer',
|
126 |
+
'description': 'You get excited in anticipation of hooking your prey.'},
|
127 |
+
{'file_name': 'alert.png',
|
128 |
+
'name': 'Alert',
|
129 |
+
'type': 'survivor',
|
130 |
+
'description': 'Your acute senses are on high alert.'}]
|
131 |
+
```
|
132 |
+
<|assistant|>
|
133 |
+
该文件看起来包含有关某些条目的元数据,每个条目有以下字段:
|
134 |
+
|
135 |
+
|
136 |
+
- `file_name`: 文件名称
|
137 |
+
- `name`: 名称
|
138 |
+
- `type`: 类型(例如 "survivor" 或 "killer")
|
139 |
+
- `description`: 描述
|
140 |
+
|
141 |
+
接下来,我将检查这些字段是否存在缺失值。
|
142 |
+
<|assistant|>interpreter
|
143 |
+
```python
|
144 |
+
# Check for missing values in each column
|
145 |
+
missing_values = {key: sum(1 for item in data if key not in item) for key in data[0].keys()}
|
146 |
+
|
147 |
+
missing_values
|
148 |
+
```
|
149 |
+
<|observation|>
|
150 |
+
```result
|
151 |
+
{'file_name': 0, 'name': 0, 'type': 0, 'description': 0}
|
152 |
+
```
|
153 |
+
<|assistant|>
|
154 |
+
每个字段都没有缺失值。这意味着每个条目都完整。
|
155 |
+
|
156 |
+
接下来,我们可以检查是否存在任何潜在的异常值。这通常需要更深入的分析和对数据的了解,但我可以从一些基本的统计和分布开始进行检查。为了这样做,我们可以查看“type”字段的值分布,看看是否有任何意外的类别。
|
157 |
+
<|assistant|>interpreter
|
158 |
+
```python
|
159 |
+
# Check the distribution of values in the 'type' column
|
160 |
+
type_distribution = {key: sum(1 for item in data if item['type'] == key) for key in set(item['type'] for item in data)}
|
161 |
+
|
162 |
+
type_distribution
|
163 |
+
```
|
164 |
+
<|observation|>
|
165 |
+
```result
|
166 |
+
{'survivor': 116, 'killer': 103}
|
167 |
+
```
|
168 |
+
<|assistant|>
|
169 |
+
“type”字段有两个值:“survivor”和“killer”。其中,“survivor”有116个条目,而“killer”有103个条目。
|
170 |
+
|
171 |
+
从这个初步的检查中,我们没有发现任何明显的异常值。
|
172 |
+
|
173 |
+
如果您有任何特定的标准或其他字段需要检查,我可以为您提供更详细的分析。您还希望我检查什么呢?
|
174 |
+
<|user|>
|
175 |
+
帮我画一个爱心
|
176 |
+
<|assistant|>interpreter
|
177 |
+
```python
|
178 |
+
import numpy as np
|
179 |
+
import matplotlib.pyplot as plt
|
180 |
+
|
181 |
+
# 爱心形状的参数方程
|
182 |
+
def heart(t):
|
183 |
+
x = 16 * np.sin(t) ** 3
|
184 |
+
y = 13 * np.cos(t) - 5 * np.cos(2 * t) - 2 * np.cos(3 * t) - np.cos(4 * t)
|
185 |
+
return x, y
|
186 |
+
|
187 |
+
t = np.linspace(0, 2 * np.pi, 1000)
|
188 |
+
x, y = heart(t)
|
189 |
+
|
190 |
+
plt.figure(figsize=(6, 6))
|
191 |
+
plt.plot(x, y, color='red')
|
192 |
+
plt.axis('equal')
|
193 |
+
plt.axis('off')
|
194 |
+
plt.show()
|
195 |
+
```
|
196 |
+
<|observation|>
|
197 |
+
```result
|
198 |
+
[Image]
|
199 |
+
```
|
200 |
+
<|assistant|>
|
201 |
+
这是一个爱心形状。我使用了参数方程来描述这个形状,并使用matplotlib进行了绘制。如果您有任何其他需求或问题,请随时告诉我。
|
202 |
+
<|user|> # End
|
203 |
+
````
|
PROMPT_en.md
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## ChatGLM3 Chat Format
|
2 |
+
To avoid injection attacks from user input, and to unify the input of Code Interpreter, Tool & Agent and other tasks, ChatGLM3 adopts a brand-new dialogue format.
|
3 |
+
|
4 |
+
### Regulations
|
5 |
+
#### Overall Structure
|
6 |
+
The format of the ChatGLM3 dialogue consists of several conversations, each of which contains a dialogue header and content. A typical multi-turn dialogue structure is as follows:
|
7 |
+
```text
|
8 |
+
<|system|>
|
9 |
+
You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.
|
10 |
+
<|user|>
|
11 |
+
Hello
|
12 |
+
<|assistant|>
|
13 |
+
Hello, I'm ChatGLM3. What can I assist you today?
|
14 |
+
```
|
15 |
+
|
16 |
+
#### Chat Header
|
17 |
+
The chat header occupies a complete line, formatted as:
|
18 |
+
```text
|
19 |
+
<|role|>{metadata}
|
20 |
+
```
|
21 |
+
Where `<|role|>` part is represented in a special token, which can’t be encoded by the tokenizer from the text form to prevent injection attacks. The `metadata` part is represented in plain texts and is optional content.
|
22 |
+
* `<|system|>`: System information, which can be interspersed in the dialogue in design, **but currently only appears at the beginning**
|
23 |
+
* `<|user|>`: User
|
24 |
+
- Multiple messages from `<|user|>` will not appear continuously
|
25 |
+
* `<|assistant|>`: AI assistant
|
26 |
+
- There must be a message from `<|user|>` before it appears
|
27 |
+
* `<|observation|>`: External return result
|
28 |
+
- Must be after the message from `<|assistant|>`
|
29 |
+
|
30 |
+
### Example Scenarios
|
31 |
+
|
32 |
+
For better readablity, an extra `\n` is added before each role special token. This extra `\n` should not be added in actual use and tokenizer implementation.
|
33 |
+
|
34 |
+
#### Multi-turn Dialogue
|
35 |
+
* There are only three roles: `<|user|>`, `<|assistant|>`, and `<|system|>`.
|
36 |
+
```text
|
37 |
+
<|system|>
|
38 |
+
You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.
|
39 |
+
<|user|>
|
40 |
+
Hello
|
41 |
+
<|assistant|>
|
42 |
+
Hello, I'm ChatGLM3. What can I assist you today?
|
43 |
+
```
|
44 |
+
|
45 |
+
#### Tool Calling
|
46 |
+
````
|
47 |
+
<|system|>
|
48 |
+
Answer the following questions as best as you can. You have access to the following tools:
|
49 |
+
[
|
50 |
+
{
|
51 |
+
"name": "get_current_weather",
|
52 |
+
"description": "Get the current weather in a given location",
|
53 |
+
"parameters": {
|
54 |
+
"type": "object",
|
55 |
+
"properties": {
|
56 |
+
"location": {
|
57 |
+
"type": "string",
|
58 |
+
"description": "The city and state, e.g. San Francisco, CA",
|
59 |
+
},
|
60 |
+
"unit": {"type": "string"},
|
61 |
+
},
|
62 |
+
"required": ["location"],
|
63 |
+
},
|
64 |
+
}
|
65 |
+
]
|
66 |
+
<|user|>
|
67 |
+
What's the weather in Beijing today?
|
68 |
+
<|assistant|>
|
69 |
+
Okay, let's look up the weather in Bejing today.
|
70 |
+
<|assistant|>get_current_weather
|
71 |
+
```python
|
72 |
+
tool_call(location="beijing", unit="celsius")
|
73 |
+
```
|
74 |
+
<|observation|>
|
75 |
+
{"temperature": 22}
|
76 |
+
<|assistant|>
|
77 |
+
According to the query results, the temperature in Beijing today is 22 degrees Celsius.
|
78 |
+
````
|
79 |
+
|
80 |
+
#### Code Execution
|
81 |
+
* Four available roles: `<|user|>`, `<|assistant|>`, `<|system|>`, and `<|observation|>`, where `metadata` of `<|assistant|>` can only be interpreter.
|
82 |
+
|
83 |
+
`````text
|
84 |
+
<|system|>
|
85 |
+
You are an intelligent AI assistant, named ChatGLM3. You are connected to a computer, but please note that you cannot connect to the internet. When using Python to solve tasks, you can run the code and get the results. If there are errors in the running results, you need to improve the code as much as possible. You can process files uploaded by users to your computer, and the default storage path for files is /mnt/data/.
|
86 |
+
<|user|>
|
87 |
+
#File: /mnt/data/metadata.jsonl
|
88 |
+
#Size: 35380
|
89 |
+
#File uploaded
|
90 |
+
Are there missing or abnormal values in the file?
|
91 |
+
<|assistant|>
|
92 |
+
Firstly, I will read the file you provided and review its contents to determine if there are missing or abnormal values.
|
93 |
+
|
94 |
+
We start by loading the file:
|
95 |
+
<|assistant|>interpreter
|
96 |
+
```python
|
97 |
+
import json
|
98 |
+
|
99 |
+
# Load the uploaded file
|
100 |
+
data = []
|
101 |
+
with open("/mnt/data/metadata.jsonl", 'r') as file:
|
102 |
+
for line in file:
|
103 |
+
data.append(json.loads(line))
|
104 |
+
|
105 |
+
# Display the first few records to get a sense of the data
|
106 |
+
data[:5]
|
107 |
+
```
|
108 |
+
<|observation|>
|
109 |
+
```result
|
110 |
+
[{'file_name': 'aceinthehole.png',
|
111 |
+
'name': 'Ace in the Hole',
|
112 |
+
'type': 'survivor',
|
113 |
+
'description': 'Lady Luck always seems to be throwing something good your way.'},
|
114 |
+
{'file_name': 'adrenaline.png',
|
115 |
+
'name': 'Adrenaline',
|
116 |
+
'type': 'survivor',
|
117 |
+
'description': 'You are fuelled by unexpected energy when on the verge of escape.'},
|
118 |
+
{'file_name': 'aftercare.png',
|
119 |
+
'name': 'Aftercare',
|
120 |
+
'type': 'survivor',
|
121 |
+
'description': "Unlocks potential in one's aura reading ability."},
|
122 |
+
{'file_name': 'agitation.png',
|
123 |
+
'name': 'Agitation',
|
124 |
+
'type': 'killer',
|
125 |
+
'description': 'You get excited in anticipation of hooking your prey.'},
|
126 |
+
{'file_name': 'alert.png',
|
127 |
+
'name': 'Alert',
|
128 |
+
'type': 'survivor',
|
129 |
+
'description': 'Your acute senses are on high alert.'}]
|
130 |
+
```
|
131 |
+
<|assistant|>
|
132 |
+
This file appears to contain metadata about certain items, each with the following fields:
|
133 |
+
|
134 |
+
- `file_name`
|
135 |
+
- `name`
|
136 |
+
- `type` (e.g. "survivor" or "killer")
|
137 |
+
- `description`
|
138 |
+
|
139 |
+
Next, I will check for missing values in these fields.
|
140 |
+
<|assistant|>interpreter
|
141 |
+
```python
|
142 |
+
# Check for missing values in each column
|
143 |
+
missing_values = {key: sum(1 for item in data if key not in item) for key in data[0].keys()}
|
144 |
+
|
145 |
+
missing_values
|
146 |
+
```
|
147 |
+
<|observation|>
|
148 |
+
```result
|
149 |
+
{'file_name': 0, 'name': 0, 'type': 0, 'description': 0}
|
150 |
+
```
|
151 |
+
<|assistant|>
|
152 |
+
Each field has no missing values. This means that each entry is complete.
|
153 |
+
|
154 |
+
Next, we can check for any potential outliers. This usually requires deeper analysis and understanding of the data, but I can start by examining some basic statistics and distributions. To do this, we can check the value distribution of the "type" field to see if there are any unexpected categories.
|
155 |
+
<|assistant|>interpreter
|
156 |
+
```python
|
157 |
+
# Check the distribution of values in the 'type' column
|
158 |
+
type_distribution = {key: sum(1 for item in data if item['type'] == key) for key in set(item['type'] for item in data)}
|
159 |
+
|
160 |
+
type_distribution
|
161 |
+
```
|
162 |
+
<|observation|>
|
163 |
+
```result
|
164 |
+
{'survivor': 116, 'killer': 103}
|
165 |
+
```
|
166 |
+
<|assistant|>
|
167 |
+
The 'type' field has two values: 'survivor' and 'killer'. Among them, 'survivor' has 116 entries, while 'killer' has 103 entries.
|
168 |
+
|
169 |
+
From this preliminary examination, we did not find any obvious abnormal values.
|
170 |
+
|
171 |
+
If you have any specific standards or other fields that need to be checked, I can provide you with a more detailed analysis. What else do you want me to check?
|
172 |
+
<|user|>
|
173 |
+
Help me draw a heart
|
174 |
+
<|assistant|>interpreter
|
175 |
+
```python
|
176 |
+
import numpy as np
|
177 |
+
import matplotlib.pyplot as plt
|
178 |
+
|
179 |
+
# The parametric equation for the shape of a heart
|
180 |
+
def heart(t):
|
181 |
+
x = 16 * np.sin(t) ** 3
|
182 |
+
y = 13 * np.cos(t) - 5 * np.cos(2 * t) - 2 * np.cos(3 * t) - np.cos(4 * t)
|
183 |
+
return x, y
|
184 |
+
|
185 |
+
t = np.linspace(0, 2 * np.pi, 1000)
|
186 |
+
x, y = heart(t)
|
187 |
+
|
188 |
+
plt.figure(figsize=(6, 6))
|
189 |
+
plt.plot(x, y, color='red')
|
190 |
+
plt.axis('equal')
|
191 |
+
plt.axis('off')
|
192 |
+
plt.show()
|
193 |
+
```
|
194 |
+
<|observation|>
|
195 |
+
```result
|
196 |
+
[Image]
|
197 |
+
```
|
198 |
+
<|assistant|>
|
199 |
+
This is a shape of heart. I used parametric equations to describe this shape and drew it using Matplotlib. If you have any other needs or questions, please feel free to let me know.
|
200 |
+
<|user|> # End
|
201 |
+
````
|
README_en.md
ADDED
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ChatGLM3
|
2 |
+
|
3 |
+
<p align="center">
|
4 |
+
🤗 <a href="https://huggingface.co/THUDM/chatglm3-6b" target="_blank">HF Repo</a> • 🤖 <a href="https://modelscope.cn/models/ZhipuAI/chatglm3-6b" target="_blank">ModelScope</a> • 📔 <a href="https://lslfd0slxc.feishu.cn/wiki/WvQbwIJ9tiPAxGk8ywDck6yfnof" target="_blank">Document</a> • 🐦 <a href="https://twitter.com/thukeg" target="_blank">Twitter</a> • 📃 <a href="https://arxiv.org/abs/2103.10360" target="_blank">[GLM@ACL 22]</a> <a href="https://github.com/THUDM/GLM" target="_blank">[GitHub]</a> • 📃 <a href="https://arxiv.org/abs/2210.02414" target="_blank">[GLM-130B@ICLR 23]</a> <a href="https://github.com/THUDM/GLM-130B" target="_blank">[GitHub]</a> <br>
|
5 |
+
</p>
|
6 |
+
<p align="center">
|
7 |
+
👋 Join our <a href="https://join.slack.com/t/chatglm/shared_invite/zt-25ti5uohv-A_hs~am_D3Q8XPZMpj7wwQ" target="_blank">Slack</a> and <a href="resources/WECHAT.md" target="_blank">WeChat</a>
|
8 |
+
</p>
|
9 |
+
<p align="center">
|
10 |
+
📍Experience the larger-scale ChatGLM model at <a href="https://www.chatglm.cn">chatglm.cn</a>
|
11 |
+
</p>
|
12 |
+
|
13 |
+
## Introduction
|
14 |
+
|
15 |
+
ChatGLM3 is a new generation of pre-trained dialogue models jointly released by Zhipu AI and Tsinghua KEG. ChatGLM3-6B is the open-source model in the ChatGLM3 series, maintaining many excellent features of the first two generations such as smooth dialogue and low deployment threshold, while introducing the following features:
|
16 |
+
|
17 |
+
1. **Stronger Base Model:** The base model of ChatGLM3-6B, ChatGLM3-6B-Base, adopts a more diverse training dataset, more sufficient training steps, and a more reasonable training strategy. Evaluations on datasets from various perspectives such as semantics, mathematics, reasoning, code, and knowledge show that **ChatGLM3-6B-Base has the strongest performance among base models below 10B**.
|
18 |
+
|
19 |
+
2. **More Complete Function Support:** ChatGLM3-6B adopts a newly designed [Prompt format](PROMPT_en.md), supporting multi-turn dialogues as usual. It also natively supports [tool invocation](tool_using/README_en.md) (Function Call), code execution (Code Interpreter), and Agent tasks in complex scenarios.
|
20 |
+
|
21 |
+
3. **More Comprehensive Open-source Series:** In addition to the dialogue model [ChatGLM3-6B](https://huggingface.co/THUDM/chatglm3-6b), the basic model [ChatGLM3-6B-Base](https://huggingface.co/THUDM/chatglm3-6b-base), and the long-text dialogue model [ChatGLM3-6B-32K](https://huggingface.co/THUDM/chatglm3-6b-32k) have also been open-sourced. All these weights are **fully open** for academic research, and **free commercial use is also allowed** after registration via a [questionnaire](https://open.bigmodel.cn/mla/form).
|
22 |
+
|
23 |
+
-----
|
24 |
+
|
25 |
+
The ChatGLM3 open-source model aims to promote the development of large-model technology together with the open-source community. Developers and everyone are earnestly requested to comply with the [open-source protocol](MODEL_LICENSE), and not to use the open-source models, codes, and derivatives for any purposes that might harm the nation and society, and for any services that have not been evaluated and filed for safety. Currently, no applications, including web, Android, Apple iOS, and Windows App, have been developed based on the **ChatGLM3 open-source model** by our project team.
|
26 |
+
|
27 |
+
Although every effort has been made to ensure the compliance and accuracy of the data at various stages of model training, due to the smaller scale of the ChatGLM3-6B model and the influence of probabilistic randomness factors, the accuracy of output content cannot be guaranteed. The model output is also easily misled by user input. **This project does not assume risks and liabilities caused by data security, public opinion risks, or any misleading, abuse, dissemination, and improper use of open-source models and codes.**
|
28 |
+
|
29 |
+
## Model List
|
30 |
+
|
31 |
+
| Model | Seq Length | Download
|
32 |
+
| :---: |:---------------------------:|:-----------------------------------------------------------------------------------------------------------------------------------:
|
33 |
+
| ChatGLM3-6B | 8k | [HuggingFace](https://huggingface.co/THUDM/chatglm3-6b) \| [ModelScope](https://modelscope.cn/models/ZhipuAI/chatglm3-6b)
|
34 |
+
| ChatGLM3-6B-Base | 8k | [HuggingFace](https://huggingface.co/THUDM/chatglm3-6b-base) \| [ModelScope](https://modelscope.cn/models/ZhipuAI/chatglm3-6b-base)
|
35 |
+
| ChatGLM3-6B-32K | 32k | [HuggingFace](https://huggingface.co/THUDM/chatglm3-6b-32k) \| [ModelScope](https://modelscope.cn/models/ZhipuAI/chatglm3-6b-32k)
|
36 |
+
|
37 |
+
## Projects
|
38 |
+
Open source projects that accelerate ChatGLM3:
|
39 |
+
* [chatglm.cpp](https://github.com/li-plus/chatglm.cpp): Real-time inference on your laptop accelerated by quantization, similar to llama.cpp.
|
40 |
+
* [ChatGLM3-TPU](https://github.com/sophgo/ChatGLM3-TPU): Using the TPU accelerated inference solution, it runs about 7.5 token/s in real time on the end-side chip BM1684X (16T@FP16, 16G DDR).
|
41 |
+
|
42 |
+
## Evaluation Results
|
43 |
+
|
44 |
+
### Typical Tasks
|
45 |
+
|
46 |
+
We selected 8 typical Chinese-English datasets and conducted performance tests on the ChatGLM3-6B (base) version.
|
47 |
+
|
48 |
+
| Model | GSM8K | MATH | BBH | MMLU | C-Eval | CMMLU | MBPP | AGIEval |
|
49 |
+
|------------------|:-----:|:----:|:----:|:----:|:------:|:-----:|:----:|:-------:|
|
50 |
+
| ChatGLM2-6B-Base | 32.4 | 6.5 | 33.7 | 47.9 | 51.7 | 50.0 | - | - |
|
51 |
+
| Best Baseline | 52.1 | 13.1 | 45.0 | 60.1 | 63.5 | 62.2 | 47.5 | 45.8 |
|
52 |
+
| ChatGLM3-6B-Base | 72.3 | 25.7 | 66.1 | 61.4 | 69.0 | 67.5 | 52.4 | 53.7 |
|
53 |
+
> "Best Baseline" refers to the pre-trained models that perform best on the corresponding datasets with model parameters below 10B, excluding models that are trained specifically for a single task and do not maintain general capabilities.
|
54 |
+
|
55 |
+
> In the tests of ChatGLM3-6B-Base, BBH used a 3-shot test, GSM8K and MATH that require inference used a 0-shot CoT test, MBPP used a 0-shot generation followed by running test cases to calculate Pass@1, and other multiple-choice type datasets all used a 0-shot test.
|
56 |
+
|
57 |
+
We have conducted manual evaluation tests on ChatGLM3-6B-32K in multiple long-text application scenarios. Compared with the second-generation model, its effect has improved by more than 50% on average. In applications such as paper reading, document summarization, and financial report analysis, this improvement is particularly significant. In addition, we also tested the model on the LongBench evaluation set, and the specific results are shown in the table below.
|
58 |
+
|
59 |
+
| Model | Average | Summary | Single-Doc QA | Multi-Doc QA | Code | Few-shot | Synthetic |
|
60 |
+
|----------------------|:-----:|:----:|:----:|:----:|:------:|:-----:|:-----:|
|
61 |
+
| ChatGLM2-6B-32K | 41.5 | 24.8 | 37.6 | 34.7 | 52.8 | 51.3 | 47.7 |
|
62 |
+
| ChatGLM3-6B-32K | 50.2 | 26.6 | 45.8 | 46.1 | 56.2 | 61.2 | 65 |
|
63 |
+
|
64 |
+
|
65 |
+
## How to Use
|
66 |
+
|
67 |
+
### Environment Installation
|
68 |
+
First, you need to download this repository:
|
69 |
+
```shell
|
70 |
+
git clone https://github.com/THUDM/ChatGLM3
|
71 |
+
cd ChatGLM3
|
72 |
+
```
|
73 |
+
|
74 |
+
Then use pip to install the dependencies:
|
75 |
+
```
|
76 |
+
pip install -r requirements.txt
|
77 |
+
```
|
78 |
+
+ The `transformers` library version should be `4.30.2` and above, and `torch` library should be 2.0 and above to obtain the best inference performance.
|
79 |
+
+ In order to ensure that the version of `torch` is correct, please strictly follow the instructions of [official documentation](https://pytorch.org/get-started/locally/) for installation.
|
80 |
+
+ The `gradio` library version should be the `3.x` version.
|
81 |
+
|
82 |
+
|
83 |
+
### Integrated Demo
|
84 |
+
|
85 |
+
We provide an integrated demo that incorporates the following three functionalities. Please refer to [Integrated Demo](composite_demo/README_en.md) for how to run it.
|
86 |
+
|
87 |
+
- Chat: Dialogue mode, where you can interact with the model.
|
88 |
+
- Tool: Tool mode, where in addition to dialogue, the model can also perform other operations using tools.
|
89 |
+

|
90 |
+
- Code Interpreter: Code interpreter mode, where the model can execute code in a Jupyter environment and obtain results to complete complex tasks.
|
91 |
+

|
92 |
+
|
93 |
+
### Usage
|
94 |
+
|
95 |
+
The ChatGLM model can be called to start a conversation using the following code:
|
96 |
+
|
97 |
+
```python
|
98 |
+
>>> from transformers import AutoTokenizer, AutoModel
|
99 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
|
100 |
+
>>> model = AutoModel.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True, device='cuda')
|
101 |
+
>>> model = model.eval()
|
102 |
+
>>> response, history = model.chat(tokenizer, "Hello", history=[])
|
103 |
+
>>> print(response)
|
104 |
+
Hello 👋! I'm ChatGLM3-6B, the artificial intelligence assistant, nice to meet you. Feel free to ask me any questions.
|
105 |
+
>>> response, history = model.chat(tokenizer, "What should I do if I can't sleep at night", history=history)
|
106 |
+
>>> print(response)
|
107 |
+
If you're having trouble sleeping at night, here are a few suggestions that might help:
|
108 |
+
|
109 |
+
1. Create a relaxing sleep environment: Make sure your bedroom is cool, quiet, and dark. Consider using earplugs, a white noise machine, or a fan to help create an optimal environment.
|
110 |
+
2. Establish a bedtime routine: Try to go to bed and wake up at the same time every day, even on weekends. A consistent routine can help regulate your body's internal clock.
|
111 |
+
3. Avoid stimulating activities before bedtime: Avoid using electronic devices, watching TV, or engaging in stimulating activities like exercise or puzzle-solving, as these can interfere with your ability to fall asleep.
|
112 |
+
4. Limit caffeine and alcohol: Avoid consuming caffeine and alcohol close to bedtime, as these can disrupt your sleep patterns.
|
113 |
+
5. Practice relaxation techniques: Try meditation, deep breathing, or progressive muscle relaxation to help calm your mind and body before sleep.
|
114 |
+
6. Consider taking a warm bath or shower: A warm bath or shower can help relax your muscles and promote sleep.
|
115 |
+
7. Get some fresh air: Make sure to get some fresh air during the day, as lack of vitamin D can interfere with sleep quality.
|
116 |
+
|
117 |
+
If you continue to have difficulty sleeping, consult with a healthcare professional for further guidance and support.
|
118 |
+
```
|
119 |
+
|
120 |
+
#### Load Model Locally
|
121 |
+
The above code will automatically download the model implementation and parameters by `transformers`. The complete model implementation is available on [Hugging Face Hub](https://huggingface.co/THUDM/chatglm3-6b). If your network environment is poor, downloading model parameters might take a long time or even fail. In this case, you can first download the model to your local machine, and then load it from there.
|
122 |
+
|
123 |
+
To download the model from Hugging Face Hub, you need to [install Git LFS](https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage) first, then run
|
124 |
+
```Shell
|
125 |
+
git clone https://huggingface.co/THUDM/chatglm3-6b
|
126 |
+
```
|
127 |
+
|
128 |
+
If the download from HuggingFace is slow, you can also download it from [ModelScope](https://modelscope.cn/models/ZhipuAI/chatglm3-6b).
|
129 |
+
|
130 |
+
# Model Fine-tuning
|
131 |
+
Please refer to the dialog model fine-tuning [ChatGLM3-6B fine-tuning example](finetune_chatmodel_demo/README.md), or the base model fine-tuning [ChatGLM3-6B-base fine-tuning example](finetune_basemodel_demo/README.md).
|
132 |
+
Please note that different fine-tuning scripts correspond to different models. Please select the corresponding model according to your needs.
|
133 |
+
|
134 |
+
### Web-based Dialogue Demo
|
135 |
+

|
136 |
+
You can launch a web-based demo using Gradio with the following command:
|
137 |
+
```shell
|
138 |
+
python web_demo.py
|
139 |
+
```
|
140 |
+
|
141 |
+

|
142 |
+
|
143 |
+
You can launch a web-based demo using Streamlit with the following command:
|
144 |
+
```shell
|
145 |
+
streamlit run web_demo2.py
|
146 |
+
```
|
147 |
+
|
148 |
+
The web-based demo will run a Web Server and output an address. You can use it by opening the output address in a browser. Based on tests, the web-based demo using Streamlit runs more smoothly.
|
149 |
+
|
150 |
+
### Command Line Dialogue Demo
|
151 |
+
|
152 |
+

|
153 |
+
|
154 |
+
Run [cli_demo.py](basic_demo/cli_demo.py) in the repository:
|
155 |
+
|
156 |
+
```shell
|
157 |
+
python cli_demo.py
|
158 |
+
```
|
159 |
+
|
160 |
+
The program will interact in the command line, enter instructions in the command line and hit enter to generate a response. Enter `clear` to clear the dialogue history, enter `stop` to terminate the program.
|
161 |
+
|
162 |
+
### API Deployment
|
163 |
+
Thanks to [@xusenlinzy](https://github.com/xusenlinzy) for implementing the OpenAI format streaming API deployment, which can serve as the backend for any ChatGPT-based application, such as [ChatGPT-Next-Web](https://github.com/Yidadaa/ChatGPT-Next-Web). You can deploy it by running [openai_api.py](openai_api_demo/openai_api.py) in the repository:
|
164 |
+
```shell
|
165 |
+
cd openai_api_demo
|
166 |
+
python openai_api.py
|
167 |
+
```
|
168 |
+
Also, we have written a sample code to test the performance of the API calls. This can be tested by running [openai_api_request.py](openai_api_demo/openai_api_request.py) in the repository
|
169 |
+
+ Test with Curl
|
170 |
+
```shell
|
171 |
+
curl -X POST "http://127.0.0.1:8000/v1/chat/completions" \
|
172 |
+
-H "Content-Type: application/json" \\
|
173 |
+
-d "{\"model\": \"chatglm3-6b\", \"messages\": [{\"role\": \"system\", \"content\": \"You are ChatGLM3, a large language model trained by Zhipu. Follow the user's instructions carefully. Respond using markdown.\"}, {\"role\": \"user\", \"content\": \"Hello, tell me a story, about 100 words\"}], \"stream\": false, \"max_title": \"\". false, \"max_tokens\": 100, \"temperature\": 0.8, \"top_p\": 0.8}"
|
174 |
+
````
|
175 |
+
+ Testing with Python
|
176 |
+
```shell
|
177 |
+
cd openai_api_demo
|
178 |
+
python openai_api_request.py
|
179 |
+
```
|
180 |
+
If the test is successful, the model should return a story.
|
181 |
+
|
182 |
+
### Tool Invocation
|
183 |
+
|
184 |
+
For methods of tool invocation, please refer to [Tool Invocation](tool_using/README_en.md).
|
185 |
+
|
186 |
+
## Low-Cost Deployment
|
187 |
+
|
188 |
+
### Model Quantization
|
189 |
+
|
190 |
+
By default, the model is loaded with FP16 precision, running the above code requires about 13GB of VRAM. If your GPU's VRAM is limited, you can try loading the model quantitatively, as follows:
|
191 |
+
|
192 |
+
```python
|
193 |
+
model = AutoModel.from_pretrained("THUDM/chatglm3-6b",trust_remote_code=True).quantize(4).cuda()
|
194 |
+
```
|
195 |
+
|
196 |
+
Model quantization will bring some performance loss. Through testing, ChatGLM3-6B can still perform natural and smooth generation under 4-bit quantization.
|
197 |
+
|
198 |
+
### CPU Deployment
|
199 |
+
|
200 |
+
If you don't have GPU hardware, you can also run inference on the CPU, but the inference speed will be slower. The usage is as follows (requires about 32GB of memory):
|
201 |
+
|
202 |
+
```python
|
203 |
+
model = AutoModel.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True).float()
|
204 |
+
```
|
205 |
+
|
206 |
+
### Mac Deployment
|
207 |
+
|
208 |
+
For Macs equipped with Apple Silicon or AMD GPUs, the MPS backend can be used to run ChatGLM3-6B on the GPU. Refer to Apple's [official instructions](https://developer.apple.com/metal/pytorch) to install PyTorch-Nightly (the correct version number should be 2.x.x.dev2023xxxx, not 2.x.x).
|
209 |
+
|
210 |
+
Currently, only [loading the model locally](README_en.md#load-model-locally) is supported on MacOS. Change the model loading in the code to load locally and use the MPS backend:
|
211 |
+
|
212 |
+
```python
|
213 |
+
model = AutoModel.from_pretrained("your local path", trust_remote_code=True).to('mps')
|
214 |
+
```
|
215 |
+
|
216 |
+
Loading the half-precision ChatGLM3-6B model requires about 13GB of memory. Machines with smaller memory (such as a 16GB memory MacBook Pro) will use virtual memory on the hard disk when there is insufficient free memory, resulting in a significant slowdown in inference speed.
|
217 |
+
|
218 |
+
### Multi-GPU Deployment
|
219 |
+
|
220 |
+
If you have multiple GPUs, but each GPU's VRAM size is not enough to accommodate the complete model, then the model can be split across multiple GPUs. First, install accelerate: `pip install accelerate`, and then load the model through the following methods:
|
221 |
+
|
222 |
+
```python
|
223 |
+
from utils import load_model_on_gpus
|
224 |
+
|
225 |
+
model = load_model_on_gpus("THUDM/chatglm3-6b", num_gpus=2)
|
226 |
+
```
|
227 |
+
|
228 |
+
This allows the model to be deployed on two GPUs for inference. You can change `num_gpus` to the number of GPUs you want to use. It is evenly split by default, but you can also pass the `device_map` parameter to specify it yourself.
|
229 |
+
|
230 |
+
## Citation
|
231 |
+
|
232 |
+
If you find our work helpful, please consider citing the following papers.
|
233 |
+
|
234 |
+
```
|
235 |
+
@article{zeng2022glm,
|
236 |
+
title={Glm-130b: An open bilingual pre-trained model},
|
237 |
+
author={Zeng, Aohan and Liu, Xiao and Du, Zhengxiao and Wang, Zihan and Lai, Hanyu and Ding, Ming and Yang, Zhuoyi and Xu, Yifan and Zheng, Wendi and Xia, Xiao and others},
|
238 |
+
journal={arXiv preprint arXiv:2210.02414},
|
239 |
+
year={2022}
|
240 |
+
}
|
241 |
+
```
|
242 |
+
```
|
243 |
+
@inproceedings{du2022glm,
|
244 |
+
title={GLM: General Language Model Pretraining with Autoregressive Blank Infilling},
|
245 |
+
author={Du, Zhengxiao and Qian, Yujie and Liu, Xiao and Ding, Ming and Qiu, Jiezhong and Yang, Zhilin and Tang, Jie},
|
246 |
+
booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)},
|
247 |
+
pages={320--335},
|
248 |
+
year={2022}
|
249 |
+
}
|
250 |
+
```
|
Untitled.ipynb
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"id": "d1389f08-db10-44c0-acc0-08170ac549d0",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [
|
9 |
+
{
|
10 |
+
"name": "stdout",
|
11 |
+
"output_type": "stream",
|
12 |
+
"text": [
|
13 |
+
"Requirement already satisfied: huggingface_hub in /opt/conda/lib/python3.10/site-packages (0.19.4)\n",
|
14 |
+
"Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from huggingface_hub) (3.9.0)\n",
|
15 |
+
"Requirement already satisfied: fsspec>=2023.5.0 in /opt/conda/lib/python3.10/site-packages (from huggingface_hub) (2023.10.0)\n",
|
16 |
+
"Requirement already satisfied: requests in /opt/conda/lib/python3.10/site-packages (from huggingface_hub) (2.31.0)\n",
|
17 |
+
"Requirement already satisfied: tqdm>=4.42.1 in /opt/conda/lib/python3.10/site-packages (from huggingface_hub) (4.65.0)\n",
|
18 |
+
"Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.10/site-packages (from huggingface_hub) (6.0.1)\n",
|
19 |
+
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/lib/python3.10/site-packages (from huggingface_hub) (4.8.0)\n",
|
20 |
+
"Requirement already satisfied: packaging>=20.9 in /opt/conda/lib/python3.10/site-packages (from huggingface_hub) (23.1)\n",
|
21 |
+
"Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface_hub) (2.0.4)\n",
|
22 |
+
"Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface_hub) (3.4)\n",
|
23 |
+
"Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface_hub) (1.26.18)\n",
|
24 |
+
"Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface_hub) (2023.7.22)\n",
|
25 |
+
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
|
26 |
+
"\u001b[0mNote: you may need to restart the kernel to use updated packages.\n"
|
27 |
+
]
|
28 |
+
}
|
29 |
+
],
|
30 |
+
"source": [
|
31 |
+
"pip install huggingface_hub"
|
32 |
+
]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"cell_type": "code",
|
36 |
+
"execution_count": 2,
|
37 |
+
"id": "37ce339b-5f87-4268-a808-7ec1f9b04055",
|
38 |
+
"metadata": {},
|
39 |
+
"outputs": [
|
40 |
+
{
|
41 |
+
"data": {
|
42 |
+
"application/vnd.jupyter.widget-view+json": {
|
43 |
+
"model_id": "e74cef2b1046454abf03203299de8bd7",
|
44 |
+
"version_major": 2,
|
45 |
+
"version_minor": 0
|
46 |
+
},
|
47 |
+
"text/plain": [
|
48 |
+
"VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
|
49 |
+
]
|
50 |
+
},
|
51 |
+
"metadata": {},
|
52 |
+
"output_type": "display_data"
|
53 |
+
}
|
54 |
+
],
|
55 |
+
"source": [
|
56 |
+
"from huggingface_hub import login\n",
|
57 |
+
"login()"
|
58 |
+
]
|
59 |
+
},
|
60 |
+
{
|
61 |
+
"cell_type": "code",
|
62 |
+
"execution_count": null,
|
63 |
+
"id": "b338999a-2f2e-45fd-8bfe-ff5c035e79f3",
|
64 |
+
"metadata": {},
|
65 |
+
"outputs": [
|
66 |
+
{
|
67 |
+
"data": {
|
68 |
+
"application/vnd.jupyter.widget-view+json": {
|
69 |
+
"model_id": "ff4e7de5abd24397a2476b50df747fc0",
|
70 |
+
"version_major": 2,
|
71 |
+
"version_minor": 0
|
72 |
+
},
|
73 |
+
"text/plain": [
|
74 |
+
"pytorch_model-00002-of-00007.bin: 0%| | 0.00/1.97G [00:00<?, ?B/s]"
|
75 |
+
]
|
76 |
+
},
|
77 |
+
"metadata": {},
|
78 |
+
"output_type": "display_data"
|
79 |
+
},
|
80 |
+
{
|
81 |
+
"data": {
|
82 |
+
"application/vnd.jupyter.widget-view+json": {
|
83 |
+
"model_id": "4810a0b8ac7c4537ae3b825eaaacd03d",
|
84 |
+
"version_major": 2,
|
85 |
+
"version_minor": 0
|
86 |
+
},
|
87 |
+
"text/plain": [
|
88 |
+
"Upload 11 LFS files: 0%| | 0/11 [00:00<?, ?it/s]"
|
89 |
+
]
|
90 |
+
},
|
91 |
+
"metadata": {},
|
92 |
+
"output_type": "display_data"
|
93 |
+
},
|
94 |
+
{
|
95 |
+
"data": {
|
96 |
+
"application/vnd.jupyter.widget-view+json": {
|
97 |
+
"model_id": "2c378790ca33452bac16db72d1a93b5d",
|
98 |
+
"version_major": 2,
|
99 |
+
"version_minor": 0
|
100 |
+
},
|
101 |
+
"text/plain": [
|
102 |
+
"pytorch_model-00005-of-00007.bin: 0%| | 0.00/1.97G [00:00<?, ?B/s]"
|
103 |
+
]
|
104 |
+
},
|
105 |
+
"metadata": {},
|
106 |
+
"output_type": "display_data"
|
107 |
+
},
|
108 |
+
{
|
109 |
+
"data": {
|
110 |
+
"application/vnd.jupyter.widget-view+json": {
|
111 |
+
"model_id": "19aa1f962f30403d8487851207ebc11f",
|
112 |
+
"version_major": 2,
|
113 |
+
"version_minor": 0
|
114 |
+
},
|
115 |
+
"text/plain": [
|
116 |
+
"pytorch_model-00003-of-00007.bin: 0%| | 0.00/1.93G [00:00<?, ?B/s]"
|
117 |
+
]
|
118 |
+
},
|
119 |
+
"metadata": {},
|
120 |
+
"output_type": "display_data"
|
121 |
+
},
|
122 |
+
{
|
123 |
+
"data": {
|
124 |
+
"application/vnd.jupyter.widget-view+json": {
|
125 |
+
"model_id": "700a7f07898f476c84dee0481d5dd594",
|
126 |
+
"version_major": 2,
|
127 |
+
"version_minor": 0
|
128 |
+
},
|
129 |
+
"text/plain": [
|
130 |
+
"pytorch_model-00001-of-00007.bin: 0%| | 0.00/1.83G [00:00<?, ?B/s]"
|
131 |
+
]
|
132 |
+
},
|
133 |
+
"metadata": {},
|
134 |
+
"output_type": "display_data"
|
135 |
+
},
|
136 |
+
{
|
137 |
+
"data": {
|
138 |
+
"application/vnd.jupyter.widget-view+json": {
|
139 |
+
"model_id": "0dbd5185b81f4cb380005865a8036e8b",
|
140 |
+
"version_major": 2,
|
141 |
+
"version_minor": 0
|
142 |
+
},
|
143 |
+
"text/plain": [
|
144 |
+
"pytorch_model-00004-of-00007.bin: 0%| | 0.00/1.82G [00:00<?, ?B/s]"
|
145 |
+
]
|
146 |
+
},
|
147 |
+
"metadata": {},
|
148 |
+
"output_type": "display_data"
|
149 |
+
}
|
150 |
+
],
|
151 |
+
"source": [
|
152 |
+
"from huggingface_hub import HfApi\n",
|
153 |
+
"api = HfApi()\n",
|
154 |
+
"\n",
|
155 |
+
"api.upload_folder(\n",
|
156 |
+
" folder_path=\"\",\n",
|
157 |
+
" repo_id=\"kevinwang676/ChatGLM3-demo\",\n",
|
158 |
+
" repo_type=\"space\",\n",
|
159 |
+
")"
|
160 |
+
]
|
161 |
+
},
|
162 |
+
{
|
163 |
+
"cell_type": "code",
|
164 |
+
"execution_count": null,
|
165 |
+
"id": "ff1ddc96-5925-4978-bb9c-d016b716ade5",
|
166 |
+
"metadata": {},
|
167 |
+
"outputs": [],
|
168 |
+
"source": []
|
169 |
+
}
|
170 |
+
],
|
171 |
+
"metadata": {
|
172 |
+
"kernelspec": {
|
173 |
+
"display_name": "Python 3 (ipykernel)",
|
174 |
+
"language": "python",
|
175 |
+
"name": "python3"
|
176 |
+
},
|
177 |
+
"language_info": {
|
178 |
+
"codemirror_mode": {
|
179 |
+
"name": "ipython",
|
180 |
+
"version": 3
|
181 |
+
},
|
182 |
+
"file_extension": ".py",
|
183 |
+
"mimetype": "text/x-python",
|
184 |
+
"name": "python",
|
185 |
+
"nbconvert_exporter": "python",
|
186 |
+
"pygments_lexer": "ipython3",
|
187 |
+
"version": "3.10.13"
|
188 |
+
}
|
189 |
+
},
|
190 |
+
"nbformat": 4,
|
191 |
+
"nbformat_minor": 5
|
192 |
+
}
|
basic_demo/cli_demo.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import platform
|
3 |
+
from transformers import AutoTokenizer, AutoModel
|
4 |
+
import torch
|
5 |
+
|
6 |
+
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
|
7 |
+
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
|
8 |
+
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
9 |
+
|
10 |
+
# for Mac Computer like M1
|
11 |
+
# You Need Use Pytorch compiled with Metal
|
12 |
+
# DEVICE = 'mps'
|
13 |
+
|
14 |
+
# for AMD gpu likes MI100 (Not Official Steady Support yet)
|
15 |
+
# You Need Use Pytorch compiled with ROCm
|
16 |
+
# DEVICE = 'cuda'
|
17 |
+
|
18 |
+
# for Intel gpu likes A770 (Not Official Steady Support yet)
|
19 |
+
# You Need Use Pytorch compiled with oneDNN and install intel-extension-for-pytorch
|
20 |
+
# import intel_extension_for_pytorch as ipex
|
21 |
+
# DEVICE = 'xpu'
|
22 |
+
|
23 |
+
# for Moore Threads gpu like MTT S80 (Not Official Steady Support yet)
|
24 |
+
# You Need Use Pytorch compiled with Musa
|
25 |
+
# DEVICE = 'musa'
|
26 |
+
|
27 |
+
|
28 |
+
|
29 |
+
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
|
30 |
+
if 'cuda' in DEVICE: # AMD, NVIDIA GPU can use Half Precision
|
31 |
+
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).to(DEVICE).eval()
|
32 |
+
else: # CPU, Intel GPU and other GPU can use Float16 Precision Only
|
33 |
+
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).float().to(DEVICE).eval()
|
34 |
+
|
35 |
+
os_name = platform.system()
|
36 |
+
clear_command = 'cls' if os_name == 'Windows' else 'clear'
|
37 |
+
stop_stream = False
|
38 |
+
|
39 |
+
welcome_prompt = "欢迎使用 ChatGLM3-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
|
40 |
+
|
41 |
+
|
42 |
+
def build_prompt(history):
|
43 |
+
prompt = welcome_prompt
|
44 |
+
for query, response in history:
|
45 |
+
prompt += f"\n\n用户:{query}"
|
46 |
+
prompt += f"\n\nChatGLM3-6B:{response}"
|
47 |
+
return prompt
|
48 |
+
|
49 |
+
|
50 |
+
def main():
|
51 |
+
past_key_values, history = None, []
|
52 |
+
global stop_stream
|
53 |
+
print(welcome_prompt)
|
54 |
+
while True:
|
55 |
+
query = input("\n用户:")
|
56 |
+
if query.strip() == "stop":
|
57 |
+
break
|
58 |
+
if query.strip() == "clear":
|
59 |
+
past_key_values, history = None, []
|
60 |
+
os.system(clear_command)
|
61 |
+
print(welcome_prompt)
|
62 |
+
continue
|
63 |
+
print("\nChatGLM:", end="")
|
64 |
+
current_length = 0
|
65 |
+
for response, history, past_key_values in model.stream_chat(tokenizer, query, history=history, top_p=1,
|
66 |
+
temperature=0.01,
|
67 |
+
past_key_values=past_key_values,
|
68 |
+
return_past_key_values=True):
|
69 |
+
if stop_stream:
|
70 |
+
stop_stream = False
|
71 |
+
break
|
72 |
+
else:
|
73 |
+
print(response[current_length:], end="", flush=True)
|
74 |
+
current_length = len(response)
|
75 |
+
print("")
|
76 |
+
|
77 |
+
|
78 |
+
if __name__ == "__main__":
|
79 |
+
main()
|
basic_demo/cli_demo_bad_word_ids.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This script demonstrates how to use the `bad_words_ids` argument to filter out.
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
import platform
|
6 |
+
from transformers import AutoTokenizer, AutoModel
|
7 |
+
import torch
|
8 |
+
|
9 |
+
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
|
10 |
+
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
|
11 |
+
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
12 |
+
|
13 |
+
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
|
14 |
+
if 'cuda' in DEVICE: # AMD, NVIDIA GPU can use Half Precision
|
15 |
+
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).to(DEVICE).eval()
|
16 |
+
else: # CPU, Intel GPU and other GPU can use Float16 Precision Only
|
17 |
+
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).float().to(DEVICE).eval()
|
18 |
+
|
19 |
+
os_name = platform.system()
|
20 |
+
clear_command = 'cls' if os_name == 'Windows' else 'clear'
|
21 |
+
stop_stream = False
|
22 |
+
|
23 |
+
welcome_prompt = "欢迎使用 ChatGLM3-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
|
24 |
+
|
25 |
+
# 定义不希望出现的词汇, 你可以自定义, 在这个例子中,如果模型回答包含 "你好" 或 "ChatGLM",则会出现这个报错
|
26 |
+
# probability tensor contains either `inf`, `nan` or element < 0
|
27 |
+
|
28 |
+
bad_words = ["你好", "ChatGLM"]
|
29 |
+
# 将这些词汇转换为token ID列表,每个短语是一个子列表
|
30 |
+
bad_word_ids = [tokenizer.encode(bad_word, add_special_tokens=False) for bad_word in bad_words]
|
31 |
+
|
32 |
+
|
33 |
+
def build_prompt(history):
|
34 |
+
prompt = welcome_prompt
|
35 |
+
for query, response in history:
|
36 |
+
prompt += f"\n\n用户:{query}"
|
37 |
+
prompt += f"\n\nChatGLM3-6B:{response}"
|
38 |
+
return prompt
|
39 |
+
|
40 |
+
def main():
|
41 |
+
past_key_values, history = None, []
|
42 |
+
global stop_stream
|
43 |
+
print(welcome_prompt)
|
44 |
+
while True:
|
45 |
+
query = input("\n用户:")
|
46 |
+
if query.strip().lower() == "stop":
|
47 |
+
break
|
48 |
+
if query.strip().lower() == "clear":
|
49 |
+
past_key_values, history = None, []
|
50 |
+
os.system(clear_command)
|
51 |
+
print(welcome_prompt)
|
52 |
+
continue
|
53 |
+
|
54 |
+
# Attempt to generate a response
|
55 |
+
try:
|
56 |
+
print("\nChatGLM:", end="")
|
57 |
+
current_length = 0
|
58 |
+
response_generated = False
|
59 |
+
for response, history, past_key_values in model.stream_chat(
|
60 |
+
tokenizer, query, history=history, top_p=1,
|
61 |
+
temperature=0.01,
|
62 |
+
past_key_values=past_key_values,
|
63 |
+
return_past_key_values=True,
|
64 |
+
bad_words_ids=bad_word_ids # assuming this is implemented correctly
|
65 |
+
):
|
66 |
+
response_generated = True
|
67 |
+
# Check if the response contains any bad words
|
68 |
+
if any(bad_word in response for bad_word in bad_words):
|
69 |
+
print("我的回答涉嫌了bad word")
|
70 |
+
break # Break the loop if a bad word is detected
|
71 |
+
|
72 |
+
# Otherwise, print the generated response
|
73 |
+
print(response[current_length:], end="", flush=True)
|
74 |
+
current_length = len(response)
|
75 |
+
if not response_generated:
|
76 |
+
print("没有生成任何回答。")
|
77 |
+
except RuntimeError as e:
|
78 |
+
print(f"生成文本时发生错误:{e},这可能是涉及到设定的敏感词汇")
|
79 |
+
|
80 |
+
print("")
|
81 |
+
|
82 |
+
if __name__ == "__main__":
|
83 |
+
main()
|
basic_demo/utils.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import Dict, Union, Optional
|
3 |
+
from torch.nn import Module
|
4 |
+
from transformers import AutoModel
|
5 |
+
|
6 |
+
|
7 |
+
def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
|
8 |
+
# transformer.word_embeddings 占用1层
|
9 |
+
# transformer.final_layernorm 和 lm_head 占用1层
|
10 |
+
# transformer.layers 占用 28 层
|
11 |
+
# 总共30层分配到num_gpus张卡上
|
12 |
+
num_trans_layers = 28
|
13 |
+
per_gpu_layers = 30 / num_gpus
|
14 |
+
|
15 |
+
# bugfix: 在linux中调用torch.embedding传入的weight,input不在同一device上,导致RuntimeError
|
16 |
+
# windows下 model.device 会被设置成 transformer.word_embeddings.device
|
17 |
+
# linux下 model.device 会被设置成 lm_head.device
|
18 |
+
# 在调用chat或者stream_chat时,input_ids会被放到model.device上
|
19 |
+
# 如果transformer.word_embeddings.device和model.device不同,则会导致RuntimeError
|
20 |
+
# 因此这里将transformer.word_embeddings,transformer.final_layernorm,lm_head都放到第一张卡上
|
21 |
+
# 本文件来源于https://github.com/THUDM/ChatGLM-6B/blob/main/utils.py
|
22 |
+
# 仅此处做少许修改以支持ChatGLM3
|
23 |
+
device_map = {
|
24 |
+
'transformer.embedding.word_embeddings': 0,
|
25 |
+
'transformer.encoder.final_layernorm': 0,
|
26 |
+
'transformer.output_layer': 0,
|
27 |
+
'transformer.rotary_pos_emb': 0,
|
28 |
+
'lm_head': 0
|
29 |
+
}
|
30 |
+
|
31 |
+
used = 2
|
32 |
+
gpu_target = 0
|
33 |
+
for i in range(num_trans_layers):
|
34 |
+
if used >= per_gpu_layers:
|
35 |
+
gpu_target += 1
|
36 |
+
used = 0
|
37 |
+
assert gpu_target < num_gpus
|
38 |
+
device_map[f'transformer.encoder.layers.{i}'] = gpu_target
|
39 |
+
used += 1
|
40 |
+
|
41 |
+
return device_map
|
42 |
+
|
43 |
+
|
44 |
+
def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike], num_gpus: int = 2,
|
45 |
+
device_map: Optional[Dict[str, int]] = None, **kwargs) -> Module:
|
46 |
+
if num_gpus < 2 and device_map is None:
|
47 |
+
model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half().cuda()
|
48 |
+
else:
|
49 |
+
from accelerate import dispatch_model
|
50 |
+
|
51 |
+
model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half()
|
52 |
+
|
53 |
+
if device_map is None:
|
54 |
+
device_map = auto_configure_device_map(num_gpus)
|
55 |
+
|
56 |
+
model = dispatch_model(model, device_map=device_map)
|
57 |
+
|
58 |
+
return model
|
basic_demo/web_demo.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from transformers import AutoModel, AutoTokenizer
|
3 |
+
import gradio as gr
|
4 |
+
import mdtex2html
|
5 |
+
from utils import load_model_on_gpus
|
6 |
+
import torch
|
7 |
+
|
8 |
+
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
|
9 |
+
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
|
10 |
+
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
11 |
+
|
12 |
+
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
|
13 |
+
if 'cuda' in DEVICE: # AMD, NVIDIA GPU can use Half Precision
|
14 |
+
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).to(DEVICE).eval()
|
15 |
+
else: # CPU, Intel GPU and other GPU can use Float16 Precision Only
|
16 |
+
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).float().to(DEVICE).eval()
|
17 |
+
|
18 |
+
# 多显卡支持,使用下面两行代替上面一行,将num_gpus改为你实际的显卡数量
|
19 |
+
# from utils import load_model_on_gpus
|
20 |
+
# model = load_model_on_gpus("THUDM/chatglm3-6b", num_gpus=2)
|
21 |
+
|
22 |
+
"""Override Chatbot.postprocess"""
|
23 |
+
|
24 |
+
def postprocess(self, y):
|
25 |
+
if y is None:
|
26 |
+
return []
|
27 |
+
for i, (message, response) in enumerate(y):
|
28 |
+
y[i] = (
|
29 |
+
None if message is None else mdtex2html.convert((message)),
|
30 |
+
None if response is None else mdtex2html.convert(response),
|
31 |
+
)
|
32 |
+
return y
|
33 |
+
|
34 |
+
|
35 |
+
gr.Chatbot.postprocess = postprocess
|
36 |
+
|
37 |
+
|
38 |
+
def parse_text(text):
|
39 |
+
"""copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
|
40 |
+
lines = text.split("\n")
|
41 |
+
lines = [line for line in lines if line != ""]
|
42 |
+
count = 0
|
43 |
+
for i, line in enumerate(lines):
|
44 |
+
if "```" in line:
|
45 |
+
count += 1
|
46 |
+
items = line.split('`')
|
47 |
+
if count % 2 == 1:
|
48 |
+
lines[i] = f'<pre><code class="language-{items[-1]}">'
|
49 |
+
else:
|
50 |
+
lines[i] = f'<br></code></pre>'
|
51 |
+
else:
|
52 |
+
if i > 0:
|
53 |
+
if count % 2 == 1:
|
54 |
+
line = line.replace("`", "\`")
|
55 |
+
line = line.replace("<", "<")
|
56 |
+
line = line.replace(">", ">")
|
57 |
+
line = line.replace(" ", " ")
|
58 |
+
line = line.replace("*", "*")
|
59 |
+
line = line.replace("_", "_")
|
60 |
+
line = line.replace("-", "-")
|
61 |
+
line = line.replace(".", ".")
|
62 |
+
line = line.replace("!", "!")
|
63 |
+
line = line.replace("(", "(")
|
64 |
+
line = line.replace(")", ")")
|
65 |
+
line = line.replace("$", "$")
|
66 |
+
lines[i] = "<br>" + line
|
67 |
+
text = "".join(lines)
|
68 |
+
return text
|
69 |
+
|
70 |
+
|
71 |
+
def predict(input, chatbot, max_length, top_p, temperature, history, past_key_values):
|
72 |
+
chatbot.append((parse_text(input), ""))
|
73 |
+
for response, history, past_key_values in model.stream_chat(tokenizer, input, history,
|
74 |
+
past_key_values=past_key_values,
|
75 |
+
return_past_key_values=True,
|
76 |
+
max_length=max_length, top_p=top_p,
|
77 |
+
temperature=temperature):
|
78 |
+
chatbot[-1] = (parse_text(input), parse_text(response))
|
79 |
+
|
80 |
+
yield chatbot, history, past_key_values
|
81 |
+
|
82 |
+
|
83 |
+
def reset_user_input():
|
84 |
+
return gr.update(value='')
|
85 |
+
|
86 |
+
|
87 |
+
def reset_state():
|
88 |
+
return [], [], None
|
89 |
+
|
90 |
+
|
91 |
+
with gr.Blocks() as demo:
|
92 |
+
gr.HTML("""<h1 align="center">ChatGLM3-6B</h1>""")
|
93 |
+
|
94 |
+
chatbot = gr.Chatbot()
|
95 |
+
with gr.Row():
|
96 |
+
with gr.Column(scale=4):
|
97 |
+
with gr.Column(scale=12):
|
98 |
+
user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
|
99 |
+
container=False)
|
100 |
+
with gr.Column(min_width=32, scale=1):
|
101 |
+
submitBtn = gr.Button("Submit", variant="primary")
|
102 |
+
with gr.Column(scale=1):
|
103 |
+
emptyBtn = gr.Button("Clear History")
|
104 |
+
max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
|
105 |
+
top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
|
106 |
+
temperature = gr.Slider(0, 1, value=0.6, step=0.01, label="Temperature", interactive=True)
|
107 |
+
|
108 |
+
history = gr.State([])
|
109 |
+
past_key_values = gr.State(None)
|
110 |
+
|
111 |
+
submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history, past_key_values],
|
112 |
+
[chatbot, history, past_key_values], show_progress=True)
|
113 |
+
submitBtn.click(reset_user_input, [], [user_input])
|
114 |
+
|
115 |
+
emptyBtn.click(reset_state, outputs=[chatbot, history, past_key_values], show_progress=True)
|
116 |
+
|
117 |
+
demo.queue().launch(share=False, server_name="127.0.0.1", server_port=8501, inbrowser=True)
|
basic_demo/web_demo2.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import streamlit as st
|
3 |
+
import torch
|
4 |
+
from transformers import AutoModel, AutoTokenizer
|
5 |
+
|
6 |
+
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
|
7 |
+
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
|
8 |
+
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
9 |
+
# 设置页面标题、图标和布局
|
10 |
+
st.set_page_config(
|
11 |
+
page_title="ChatGLM3-6B 演示",
|
12 |
+
page_icon=":robot:",
|
13 |
+
layout="wide"
|
14 |
+
)
|
15 |
+
|
16 |
+
@st.cache_resource
|
17 |
+
def get_model():
|
18 |
+
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
|
19 |
+
if 'cuda' in DEVICE: # AMD, NVIDIA GPU can use Half Precision
|
20 |
+
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).to(DEVICE).eval()
|
21 |
+
else: # CPU, Intel GPU and other GPU can use Float16 Precision Only
|
22 |
+
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).float().to(DEVICE).eval()
|
23 |
+
# 多显卡支持,使用下面两行代替上面一行,将num_gpus改为你实际的显卡数量
|
24 |
+
# from utils import load_model_on_gpus
|
25 |
+
# model = load_model_on_gpus("THUDM/chatglm3-6b", num_gpus=2)
|
26 |
+
return tokenizer, model
|
27 |
+
|
28 |
+
# 加载Chatglm3的model和tokenizer
|
29 |
+
tokenizer, model = get_model()
|
30 |
+
|
31 |
+
# 初始化历史记录和past key values
|
32 |
+
if "history" not in st.session_state:
|
33 |
+
st.session_state.history = []
|
34 |
+
if "past_key_values" not in st.session_state:
|
35 |
+
st.session_state.past_key_values = None
|
36 |
+
|
37 |
+
# 设置max_length、top_p和temperature
|
38 |
+
max_length = st.sidebar.slider("max_length", 0, 32768, 8192, step=1)
|
39 |
+
top_p = st.sidebar.slider("top_p", 0.0, 1.0, 0.8, step=0.01)
|
40 |
+
temperature = st.sidebar.slider("temperature", 0.0, 1.0, 0.6, step=0.01)
|
41 |
+
|
42 |
+
# 清理会话历史
|
43 |
+
buttonClean = st.sidebar.button("清理会话历史", key="clean")
|
44 |
+
if buttonClean:
|
45 |
+
st.session_state.history = []
|
46 |
+
st.session_state.past_key_values = None
|
47 |
+
if torch.cuda.is_available():
|
48 |
+
torch.cuda.empty_cache()
|
49 |
+
st.rerun()
|
50 |
+
|
51 |
+
# 渲染聊天历史记录
|
52 |
+
for i, message in enumerate(st.session_state.history):
|
53 |
+
if message["role"] == "user":
|
54 |
+
with st.chat_message(name="user", avatar="user"):
|
55 |
+
st.markdown(message["content"])
|
56 |
+
else:
|
57 |
+
with st.chat_message(name="assistant", avatar="assistant"):
|
58 |
+
st.markdown(message["content"])
|
59 |
+
|
60 |
+
# 输入框和输出框
|
61 |
+
with st.chat_message(name="user", avatar="user"):
|
62 |
+
input_placeholder = st.empty()
|
63 |
+
with st.chat_message(name="assistant", avatar="assistant"):
|
64 |
+
message_placeholder = st.empty()
|
65 |
+
|
66 |
+
# 获取用户输入
|
67 |
+
prompt_text = st.chat_input("请输入您的问题")
|
68 |
+
|
69 |
+
# 如果用户输入了内容,则生成回复
|
70 |
+
if prompt_text:
|
71 |
+
|
72 |
+
input_placeholder.markdown(prompt_text)
|
73 |
+
history = st.session_state.history
|
74 |
+
past_key_values = st.session_state.past_key_values
|
75 |
+
for response, history, past_key_values in model.stream_chat(
|
76 |
+
tokenizer,
|
77 |
+
prompt_text,
|
78 |
+
history,
|
79 |
+
past_key_values=past_key_values,
|
80 |
+
max_length=max_length,
|
81 |
+
top_p=top_p,
|
82 |
+
temperature=temperature,
|
83 |
+
return_past_key_values=True,
|
84 |
+
):
|
85 |
+
message_placeholder.markdown(response)
|
86 |
+
|
87 |
+
# 更新历史记录和past key values
|
88 |
+
st.session_state.history = history
|
89 |
+
st.session_state.past_key_values = past_key_values
|
chatglm3-6b/.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
chatglm3-6b/MODEL_LICENSE
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
The ChatGLM3-6B License
|
2 |
+
|
3 |
+
1. 定义
|
4 |
+
|
5 |
+
“许可方”是指分发其软件的 ChatGLM3-6B 模型团队。
|
6 |
+
|
7 |
+
“软件”是指根据本许可提供的 ChatGLM3-6B 模型参数。
|
8 |
+
|
9 |
+
2. 许可授予
|
10 |
+
|
11 |
+
根据本许可的条款和条件,许可方特此授予您非排他性、全球性、不可转让、不可再许可、可撤销、免版税的版权许可。
|
12 |
+
|
13 |
+
上述版权声明和本许可声明应包含在本软件的所有副本或重要部分中。
|
14 |
+
|
15 |
+
3.限制
|
16 |
+
|
17 |
+
您不得出于任何军事或非法目的使用、复制、修改、合并、发布、分发、复制或创建本软件的全部或部分衍生作品。
|
18 |
+
|
19 |
+
您不得利用本软件从事任何危害国家安全和国家统一、危害社会公共利益、侵犯人身权益的行为。
|
20 |
+
|
21 |
+
4.免责声明
|
22 |
+
|
23 |
+
本软件“按原样”提供,不提供任何明示或暗示的保证,包括但不限于对适销性、特定用途的适用性和非侵权性的保证。 在任何情况下,作者或版权持有人均不对任何索赔、损害或其他责任负责,无论是在合同诉讼、侵权行为还是其他方面,由软件或软件的使用或其他交易引起、由软件引起或与之相关 软件。
|
24 |
+
|
25 |
+
5. 责任限制
|
26 |
+
|
27 |
+
除适用法律禁止的范围外,在任何情况下且根据任何法律理论,无论是基于侵权行为、疏忽、合同、责任或其他原因,任何许可方均不对您承担任何直接、间接、特殊、偶然、示范性、 或间接损害,或任何其他商业损失,即使许可人已被告知此类损害的可能性。
|
28 |
+
|
29 |
+
6.争议解决
|
30 |
+
|
31 |
+
本许可受中华人民共和国法律管辖并按其解释。 因本许可引起的或与本许可有关的任何争议应提交北京市海淀区人民法院。
|
32 |
+
|
33 |
+
请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 [email protected] 与我们联系。
|
34 |
+
|
35 |
+
1. Definitions
|
36 |
+
|
37 |
+
“Licensor” means the ChatGLM3-6B Model Team that distributes its Software.
|
38 |
+
|
39 |
+
“Software” means the ChatGLM3-6B model parameters made available under this license.
|
40 |
+
|
41 |
+
2. License Grant
|
42 |
+
|
43 |
+
Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software.
|
44 |
+
|
45 |
+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
46 |
+
|
47 |
+
3. Restriction
|
48 |
+
|
49 |
+
You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any military, or illegal purposes.
|
50 |
+
|
51 |
+
You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings.
|
52 |
+
|
53 |
+
4. Disclaimer
|
54 |
+
|
55 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
56 |
+
|
57 |
+
5. Limitation of Liability
|
58 |
+
|
59 |
+
EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
|
60 |
+
|
61 |
+
6. Dispute Resolution
|
62 |
+
|
63 |
+
This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing.
|
64 |
+
|
65 |
+
Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at [email protected].
|
chatglm3-6b/README.md
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
language:
|
3 |
+
- zh
|
4 |
+
- en
|
5 |
+
tags:
|
6 |
+
- glm
|
7 |
+
- chatglm
|
8 |
+
- thudm
|
9 |
+
---
|
10 |
+
# ChatGLM3-6B
|
11 |
+
<p align="center">
|
12 |
+
💻 <a href="https://github.com/THUDM/ChatGLM" target="_blank">Github Repo</a> • 🐦 <a href="https://twitter.com/thukeg" target="_blank">Twitter</a> • 📃 <a href="https://arxiv.org/abs/2103.10360" target="_blank">[GLM@ACL 22]</a> <a href="https://github.com/THUDM/GLM" target="_blank">[GitHub]</a> • 📃 <a href="https://arxiv.org/abs/2210.02414" target="_blank">[GLM-130B@ICLR 23]</a> <a href="https://github.com/THUDM/GLM-130B" target="_blank">[GitHub]</a> <br>
|
13 |
+
</p>
|
14 |
+
|
15 |
+
<p align="center">
|
16 |
+
👋 Join our <a href="https://join.slack.com/t/chatglm/shared_invite/zt-25ti5uohv-A_hs~am_D3Q8XPZMpj7wwQ" target="_blank">Slack</a> and <a href="https://github.com/THUDM/ChatGLM/blob/main/resources/WECHAT.md" target="_blank">WeChat</a>
|
17 |
+
</p>
|
18 |
+
<p align="center">
|
19 |
+
📍Experience the larger-scale ChatGLM model at <a href="https://www.chatglm.cn">chatglm.cn</a>
|
20 |
+
</p>
|
21 |
+
|
22 |
+
## 介绍 (Introduction)
|
23 |
+
ChatGLM3-6B 是 ChatGLM 系列最新一代的开源模型,在保留了前两代模型对话流畅、部署门槛低等众多优秀特性的基础上,ChatGLM3-6B 引入了如下特性:
|
24 |
+
|
25 |
+
1. **更强大的基础模型:** ChatGLM3-6B 的基础模型 ChatGLM3-6B-Base 采用了更多样的训练数据、更充分的训练步数和更合理的训练策略。在语义、数学、推理、代码、知识等不同角度的数据集上测评显示,ChatGLM3-6B-Base 具有在 10B 以下的预训练模型中最强的性能。
|
26 |
+
2. **更完整的功能支持:** ChatGLM3-6B 采用了全新设计的 [Prompt 格式](https://github.com/THUDM/ChatGLM3/blob/main/PROMPT.md),除正常的多轮对话外。同时原生支持[工具调用](https://github.com/THUDM/ChatGLM3/blob/main/tool_using/README.md)(Function Call)、代码执行(Code Interpreter)和 Agent 任务等复杂场景。
|
27 |
+
3. **更全面的开源序列:** 除了对话模型 ChatGLM3-6B 外,还开源了基础模型 ChatGLM-6B-Base、长文本对话模型 ChatGLM3-6B-32K。以上所有权重对学术研究**完全开放**,在填写[问卷](https://open.bigmodel.cn/mla/form)进行登记后**亦允许免费商业使用**。
|
28 |
+
|
29 |
+
ChatGLM3-6B is the latest open-source model in the ChatGLM series. While retaining many excellent features such as smooth dialogue and low deployment threshold from the previous two generations, ChatGLM3-6B introduces the following features:
|
30 |
+
|
31 |
+
1. **More Powerful Base Model:** The base model of ChatGLM3-6B, ChatGLM3-6B-Base, employs a more diverse training dataset, more sufficient training steps, and a more reasonable training strategy. Evaluations on datasets such as semantics, mathematics, reasoning, code, knowledge, etc., show that ChatGLM3-6B-Base has the strongest performance among pre-trained models under 10B.
|
32 |
+
2. **More Comprehensive Function Support:** ChatGLM3-6B adopts a newly designed [Prompt format](https://github.com/THUDM/ChatGLM3/blob/main/PROMPT_en.md), in addition to the normal multi-turn dialogue. It also natively supports [function call](https://github.com/THUDM/ChatGLM3/blob/main/tool_using/README_en.md), code interpreter, and complex scenarios such as agent tasks.
|
33 |
+
3. **More Comprehensive Open-source Series:** In addition to the dialogue model ChatGLM3-6B, the base model ChatGLM-6B-Base and the long-text dialogue model ChatGLM3-6B-32K are also open-sourced. All the weights are **fully open** for academic research, and after completing the [questionnaire](https://open.bigmodel.cn/mla/form) registration, they are also **allowed for free commercial use**.
|
34 |
+
|
35 |
+
## 软件依赖 (Dependencies)
|
36 |
+
|
37 |
+
```shell
|
38 |
+
pip install protobuf transformers==4.30.2 cpm_kernels torch>=2.0 gradio mdtex2html sentencepiece accelerate
|
39 |
+
```
|
40 |
+
|
41 |
+
## 代码调用 (Code Usage)
|
42 |
+
|
43 |
+
可以通过如下代码调用 ChatGLM3-6B 模型来生成对话:
|
44 |
+
|
45 |
+
You can generate dialogue by invoking the ChatGLM3-6B model with the following code:
|
46 |
+
|
47 |
+
```ipython
|
48 |
+
>>> from transformers import AutoTokenizer, AutoModel
|
49 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
|
50 |
+
>>> model = AutoModel.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True).half().cuda()
|
51 |
+
>>> model = model.eval()
|
52 |
+
>>> response, history = model.chat(tokenizer, "你好", history=[])
|
53 |
+
>>> print(response)
|
54 |
+
你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。
|
55 |
+
>>> response, history = model.chat(tokenizer, "晚上睡不着应该怎么办", history=history)
|
56 |
+
>>> print(response)
|
57 |
+
晚上睡不着可能会让你感到焦虑或不舒服,但以下是一些可以帮助你入睡的方法:
|
58 |
+
|
59 |
+
1. 制定规律的睡眠时间表:保持规律的睡眠时间表可以帮助你建立健康的睡眠习惯,使你更容易入睡。尽量在每天的相同时间上床,并在同一时间起床。
|
60 |
+
2. 创造一个舒适的睡眠环境:确保睡眠环境舒适,安静,黑暗且温度适宜。可以使用舒适的床上用品,并保持房间通风。
|
61 |
+
3. 放松身心:在睡前做些放松的活动,例如泡个热水澡,听些轻柔的音��,阅读一些有趣的书籍等,有助于缓解紧张和焦虑,使你更容易入睡。
|
62 |
+
4. 避免饮用含有咖啡因的饮料:咖啡因是一种刺激性物质,会影响你的睡眠质量。尽量避免在睡前饮用含有咖啡因的饮料,例如咖啡,茶和可乐。
|
63 |
+
5. 避免在床上做与睡眠无关的事情:在床上做些与睡眠无关的事情,例如看电影,玩游戏或工作等,可能会干扰你的睡眠。
|
64 |
+
6. 尝试呼吸技巧:深呼吸是一种放松技巧,可以帮助你缓解紧张和焦虑,使你更容易入睡。试着慢慢吸气,保持几秒钟,然后缓慢呼气。
|
65 |
+
|
66 |
+
如果这些方法无法帮助你入睡,你可以考虑咨询医生或睡眠专家,寻求进一步的建议。
|
67 |
+
```
|
68 |
+
|
69 |
+
关于更多的使用说明,包括如何运行命令行和网页版本的 DEMO,以及使用模型量化以节省显存,请参考我们的 [Github Repo](https://github.com/THUDM/ChatGLM)。
|
70 |
+
|
71 |
+
For more instructions, including how to run CLI and web demos, and model quantization, please refer to our [Github Repo](https://github.com/THUDM/ChatGLM).
|
72 |
+
|
73 |
+
|
74 |
+
## 协议 (License)
|
75 |
+
|
76 |
+
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源,ChatGLM3-6B 模型的权重的使用则需要遵循 [Model License](MODEL_LICENSE)。
|
77 |
+
|
78 |
+
The code in this repository is open-sourced under the [Apache-2.0 license](LICENSE), while the use of the ChatGLM3-6B model weights needs to comply with the [Model License](MODEL_LICENSE).
|
79 |
+
|
80 |
+
## 引用 (Citation)
|
81 |
+
|
82 |
+
如果你觉得我们的工作有帮助的话,请考虑引用下列论文。
|
83 |
+
|
84 |
+
If you find our work helpful, please consider citing the following papers.
|
85 |
+
|
86 |
+
```
|
87 |
+
@article{zeng2022glm,
|
88 |
+
title={Glm-130b: An open bilingual pre-trained model},
|
89 |
+
author={Zeng, Aohan and Liu, Xiao and Du, Zhengxiao and Wang, Zihan and Lai, Hanyu and Ding, Ming and Yang, Zhuoyi and Xu, Yifan and Zheng, Wendi and Xia, Xiao and others},
|
90 |
+
journal={arXiv preprint arXiv:2210.02414},
|
91 |
+
year={2022}
|
92 |
+
}
|
93 |
+
```
|
94 |
+
```
|
95 |
+
@inproceedings{du2022glm,
|
96 |
+
title={GLM: General Language Model Pretraining with Autoregressive Blank Infilling},
|
97 |
+
author={Du, Zhengxiao and Qian, Yujie and Liu, Xiao and Ding, Ming and Qiu, Jiezhong and Yang, Zhilin and Tang, Jie},
|
98 |
+
booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)},
|
99 |
+
pages={320--335},
|
100 |
+
year={2022}
|
101 |
+
}
|
102 |
+
```
|
chatglm3-6b/config.json
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "THUDM/chatglm3-6b",
|
3 |
+
"model_type": "chatglm",
|
4 |
+
"architectures": [
|
5 |
+
"ChatGLMModel"
|
6 |
+
],
|
7 |
+
"auto_map": {
|
8 |
+
"AutoConfig": "configuration_chatglm.ChatGLMConfig",
|
9 |
+
"AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration",
|
10 |
+
"AutoModelForCausalLM": "modeling_chatglm.ChatGLMForConditionalGeneration",
|
11 |
+
"AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration",
|
12 |
+
"AutoModelForSequenceClassification": "modeling_chatglm.ChatGLMForSequenceClassification"
|
13 |
+
},
|
14 |
+
"add_bias_linear": false,
|
15 |
+
"add_qkv_bias": true,
|
16 |
+
"apply_query_key_layer_scaling": true,
|
17 |
+
"apply_residual_connection_post_layernorm": false,
|
18 |
+
"attention_dropout": 0.0,
|
19 |
+
"attention_softmax_in_fp32": true,
|
20 |
+
"bias_dropout_fusion": true,
|
21 |
+
"ffn_hidden_size": 13696,
|
22 |
+
"fp32_residual_connection": false,
|
23 |
+
"hidden_dropout": 0.0,
|
24 |
+
"hidden_size": 4096,
|
25 |
+
"kv_channels": 128,
|
26 |
+
"layernorm_epsilon": 1e-05,
|
27 |
+
"multi_query_attention": true,
|
28 |
+
"multi_query_group_num": 2,
|
29 |
+
"num_attention_heads": 32,
|
30 |
+
"num_layers": 28,
|
31 |
+
"original_rope": true,
|
32 |
+
"padded_vocab_size": 65024,
|
33 |
+
"post_layer_norm": true,
|
34 |
+
"rmsnorm": true,
|
35 |
+
"seq_length": 8192,
|
36 |
+
"use_cache": true,
|
37 |
+
"torch_dtype": "float16",
|
38 |
+
"transformers_version": "4.30.2",
|
39 |
+
"tie_word_embeddings": false,
|
40 |
+
"eos_token_id": 2,
|
41 |
+
"pad_token_id": 0
|
42 |
+
}
|
chatglm3-6b/configuration_chatglm.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig
|
2 |
+
|
3 |
+
|
4 |
+
class ChatGLMConfig(PretrainedConfig):
|
5 |
+
model_type = "chatglm"
|
6 |
+
def __init__(
|
7 |
+
self,
|
8 |
+
num_layers=28,
|
9 |
+
padded_vocab_size=65024,
|
10 |
+
hidden_size=4096,
|
11 |
+
ffn_hidden_size=13696,
|
12 |
+
kv_channels=128,
|
13 |
+
num_attention_heads=32,
|
14 |
+
seq_length=2048,
|
15 |
+
hidden_dropout=0.0,
|
16 |
+
classifier_dropout=None,
|
17 |
+
attention_dropout=0.0,
|
18 |
+
layernorm_epsilon=1e-5,
|
19 |
+
rmsnorm=True,
|
20 |
+
apply_residual_connection_post_layernorm=False,
|
21 |
+
post_layer_norm=True,
|
22 |
+
add_bias_linear=False,
|
23 |
+
add_qkv_bias=False,
|
24 |
+
bias_dropout_fusion=True,
|
25 |
+
multi_query_attention=False,
|
26 |
+
multi_query_group_num=1,
|
27 |
+
apply_query_key_layer_scaling=True,
|
28 |
+
attention_softmax_in_fp32=True,
|
29 |
+
fp32_residual_connection=False,
|
30 |
+
quantization_bit=0,
|
31 |
+
pre_seq_len=None,
|
32 |
+
prefix_projection=False,
|
33 |
+
**kwargs
|
34 |
+
):
|
35 |
+
self.num_layers = num_layers
|
36 |
+
self.vocab_size = padded_vocab_size
|
37 |
+
self.padded_vocab_size = padded_vocab_size
|
38 |
+
self.hidden_size = hidden_size
|
39 |
+
self.ffn_hidden_size = ffn_hidden_size
|
40 |
+
self.kv_channels = kv_channels
|
41 |
+
self.num_attention_heads = num_attention_heads
|
42 |
+
self.seq_length = seq_length
|
43 |
+
self.hidden_dropout = hidden_dropout
|
44 |
+
self.classifier_dropout = classifier_dropout
|
45 |
+
self.attention_dropout = attention_dropout
|
46 |
+
self.layernorm_epsilon = layernorm_epsilon
|
47 |
+
self.rmsnorm = rmsnorm
|
48 |
+
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
|
49 |
+
self.post_layer_norm = post_layer_norm
|
50 |
+
self.add_bias_linear = add_bias_linear
|
51 |
+
self.add_qkv_bias = add_qkv_bias
|
52 |
+
self.bias_dropout_fusion = bias_dropout_fusion
|
53 |
+
self.multi_query_attention = multi_query_attention
|
54 |
+
self.multi_query_group_num = multi_query_group_num
|
55 |
+
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
|
56 |
+
self.attention_softmax_in_fp32 = attention_softmax_in_fp32
|
57 |
+
self.fp32_residual_connection = fp32_residual_connection
|
58 |
+
self.quantization_bit = quantization_bit
|
59 |
+
self.pre_seq_len = pre_seq_len
|
60 |
+
self.prefix_projection = prefix_projection
|
61 |
+
super().__init__(**kwargs)
|
chatglm3-6b/modeling_chatglm.py
ADDED
@@ -0,0 +1,1293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" PyTorch ChatGLM model. """
|
2 |
+
|
3 |
+
import math
|
4 |
+
import copy
|
5 |
+
import warnings
|
6 |
+
import re
|
7 |
+
import sys
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.utils.checkpoint
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from torch import nn
|
13 |
+
from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
|
14 |
+
from torch.nn.utils import skip_init
|
15 |
+
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
|
16 |
+
from copy import deepcopy
|
17 |
+
|
18 |
+
from transformers.modeling_outputs import (
|
19 |
+
BaseModelOutputWithPast,
|
20 |
+
CausalLMOutputWithPast,
|
21 |
+
SequenceClassifierOutputWithPast,
|
22 |
+
)
|
23 |
+
from transformers.modeling_utils import PreTrainedModel
|
24 |
+
from transformers.utils import logging
|
25 |
+
from transformers.generation.logits_process import LogitsProcessor
|
26 |
+
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
|
27 |
+
|
28 |
+
from .configuration_chatglm import ChatGLMConfig
|
29 |
+
|
30 |
+
# flags required to enable jit fusion kernels
|
31 |
+
|
32 |
+
if sys.platform != 'darwin':
|
33 |
+
torch._C._jit_set_profiling_mode(False)
|
34 |
+
torch._C._jit_set_profiling_executor(False)
|
35 |
+
torch._C._jit_override_can_fuse_on_cpu(True)
|
36 |
+
torch._C._jit_override_can_fuse_on_gpu(True)
|
37 |
+
|
38 |
+
logger = logging.get_logger(__name__)
|
39 |
+
|
40 |
+
_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
|
41 |
+
_CONFIG_FOR_DOC = "ChatGLMConfig"
|
42 |
+
|
43 |
+
CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
44 |
+
"THUDM/chatglm3-6b",
|
45 |
+
# See all ChatGLM models at https://huggingface.co/models?filter=chatglm
|
46 |
+
]
|
47 |
+
|
48 |
+
|
49 |
+
def default_init(cls, *args, **kwargs):
|
50 |
+
return cls(*args, **kwargs)
|
51 |
+
|
52 |
+
|
53 |
+
class InvalidScoreLogitsProcessor(LogitsProcessor):
|
54 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
55 |
+
if torch.isnan(scores).any() or torch.isinf(scores).any():
|
56 |
+
scores.zero_()
|
57 |
+
scores[..., 5] = 5e4
|
58 |
+
return scores
|
59 |
+
|
60 |
+
|
61 |
+
class PrefixEncoder(torch.nn.Module):
|
62 |
+
"""
|
63 |
+
The torch.nn model to encode the prefix
|
64 |
+
Input shape: (batch-size, prefix-length)
|
65 |
+
Output shape: (batch-size, prefix-length, 2*layers*hidden)
|
66 |
+
"""
|
67 |
+
|
68 |
+
def __init__(self, config: ChatGLMConfig):
|
69 |
+
super().__init__()
|
70 |
+
self.prefix_projection = config.prefix_projection
|
71 |
+
if self.prefix_projection:
|
72 |
+
# Use a two-layer MLP to encode the prefix
|
73 |
+
kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
|
74 |
+
self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
|
75 |
+
self.trans = torch.nn.Sequential(
|
76 |
+
torch.nn.Linear(kv_size, config.hidden_size),
|
77 |
+
torch.nn.Tanh(),
|
78 |
+
torch.nn.Linear(config.hidden_size, kv_size)
|
79 |
+
)
|
80 |
+
else:
|
81 |
+
self.embedding = torch.nn.Embedding(config.pre_seq_len,
|
82 |
+
config.num_layers * config.kv_channels * config.multi_query_group_num * 2)
|
83 |
+
|
84 |
+
def forward(self, prefix: torch.Tensor):
|
85 |
+
if self.prefix_projection:
|
86 |
+
prefix_tokens = self.embedding(prefix)
|
87 |
+
past_key_values = self.trans(prefix_tokens)
|
88 |
+
else:
|
89 |
+
past_key_values = self.embedding(prefix)
|
90 |
+
return past_key_values
|
91 |
+
|
92 |
+
|
93 |
+
def split_tensor_along_last_dim(
|
94 |
+
tensor: torch.Tensor,
|
95 |
+
num_partitions: int,
|
96 |
+
contiguous_split_chunks: bool = False,
|
97 |
+
) -> List[torch.Tensor]:
|
98 |
+
"""Split a tensor along its last dimension.
|
99 |
+
|
100 |
+
Arguments:
|
101 |
+
tensor: input tensor.
|
102 |
+
num_partitions: number of partitions to split the tensor
|
103 |
+
contiguous_split_chunks: If True, make each chunk contiguous
|
104 |
+
in memory.
|
105 |
+
|
106 |
+
Returns:
|
107 |
+
A list of Tensors
|
108 |
+
"""
|
109 |
+
# Get the size and dimension.
|
110 |
+
last_dim = tensor.dim() - 1
|
111 |
+
last_dim_size = tensor.size()[last_dim] // num_partitions
|
112 |
+
# Split.
|
113 |
+
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
|
114 |
+
# Note: torch.split does not create contiguous tensors by default.
|
115 |
+
if contiguous_split_chunks:
|
116 |
+
return tuple(chunk.contiguous() for chunk in tensor_list)
|
117 |
+
|
118 |
+
return tensor_list
|
119 |
+
|
120 |
+
|
121 |
+
class RotaryEmbedding(nn.Module):
|
122 |
+
def __init__(self, dim, original_impl=False, device=None, dtype=None):
|
123 |
+
super().__init__()
|
124 |
+
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
|
125 |
+
self.register_buffer("inv_freq", inv_freq)
|
126 |
+
self.dim = dim
|
127 |
+
self.original_impl = original_impl
|
128 |
+
|
129 |
+
def forward_impl(
|
130 |
+
self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
|
131 |
+
):
|
132 |
+
"""Enhanced Transformer with Rotary Position Embedding.
|
133 |
+
|
134 |
+
Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
|
135 |
+
transformers/rope/__init__.py. MIT License:
|
136 |
+
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
|
137 |
+
"""
|
138 |
+
# $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
|
139 |
+
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))
|
140 |
+
|
141 |
+
# Create position indexes `[0, 1, ..., seq_len - 1]`
|
142 |
+
seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)
|
143 |
+
|
144 |
+
# Calculate the product of position index and $\theta_i$
|
145 |
+
idx_theta = torch.outer(seq_idx, theta).float()
|
146 |
+
|
147 |
+
cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
|
148 |
+
|
149 |
+
# this is to mimic the behaviour of complex32, else we will get different results
|
150 |
+
if dtype in (torch.float16, torch.bfloat16, torch.int8):
|
151 |
+
cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
|
152 |
+
return cache
|
153 |
+
|
154 |
+
def forward(self, max_seq_len, offset=0):
|
155 |
+
return self.forward_impl(
|
156 |
+
max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
|
157 |
+
)
|
158 |
+
|
159 |
+
|
160 |
+
@torch.jit.script
|
161 |
+
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
|
162 |
+
# x: [sq, b, np, hn]
|
163 |
+
sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
|
164 |
+
rot_dim = rope_cache.shape[-2] * 2
|
165 |
+
x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
|
166 |
+
# truncate to support variable sizes
|
167 |
+
rope_cache = rope_cache[:sq]
|
168 |
+
xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
|
169 |
+
rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
|
170 |
+
x_out2 = torch.stack(
|
171 |
+
[
|
172 |
+
xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
|
173 |
+
xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
|
174 |
+
],
|
175 |
+
-1,
|
176 |
+
)
|
177 |
+
x_out2 = x_out2.flatten(3)
|
178 |
+
return torch.cat((x_out2, x_pass), dim=-1)
|
179 |
+
|
180 |
+
|
181 |
+
class RMSNorm(torch.nn.Module):
|
182 |
+
def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
|
183 |
+
super().__init__()
|
184 |
+
self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
|
185 |
+
self.eps = eps
|
186 |
+
|
187 |
+
def forward(self, hidden_states: torch.Tensor):
|
188 |
+
input_dtype = hidden_states.dtype
|
189 |
+
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
190 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
191 |
+
|
192 |
+
return (self.weight * hidden_states).to(input_dtype)
|
193 |
+
|
194 |
+
|
195 |
+
class CoreAttention(torch.nn.Module):
|
196 |
+
def __init__(self, config: ChatGLMConfig, layer_number):
|
197 |
+
super(CoreAttention, self).__init__()
|
198 |
+
|
199 |
+
self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
|
200 |
+
self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
|
201 |
+
if self.apply_query_key_layer_scaling:
|
202 |
+
self.attention_softmax_in_fp32 = True
|
203 |
+
self.layer_number = max(1, layer_number)
|
204 |
+
|
205 |
+
projection_size = config.kv_channels * config.num_attention_heads
|
206 |
+
|
207 |
+
# Per attention head and per partition values.
|
208 |
+
self.hidden_size_per_partition = projection_size
|
209 |
+
self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
|
210 |
+
self.num_attention_heads_per_partition = config.num_attention_heads
|
211 |
+
|
212 |
+
coeff = None
|
213 |
+
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
|
214 |
+
if self.apply_query_key_layer_scaling:
|
215 |
+
coeff = self.layer_number
|
216 |
+
self.norm_factor *= coeff
|
217 |
+
self.coeff = coeff
|
218 |
+
|
219 |
+
self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
|
220 |
+
|
221 |
+
def forward(self, query_layer, key_layer, value_layer, attention_mask):
|
222 |
+
pytorch_major_version = int(torch.__version__.split('.')[0])
|
223 |
+
if pytorch_major_version >= 2:
|
224 |
+
query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
|
225 |
+
if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
|
226 |
+
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
|
227 |
+
is_causal=True)
|
228 |
+
else:
|
229 |
+
if attention_mask is not None:
|
230 |
+
attention_mask = ~attention_mask
|
231 |
+
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
|
232 |
+
attention_mask)
|
233 |
+
context_layer = context_layer.permute(2, 0, 1, 3)
|
234 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
|
235 |
+
context_layer = context_layer.reshape(*new_context_layer_shape)
|
236 |
+
else:
|
237 |
+
# Raw attention scores
|
238 |
+
|
239 |
+
# [b, np, sq, sk]
|
240 |
+
output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
|
241 |
+
|
242 |
+
# [sq, b, np, hn] -> [sq, b * np, hn]
|
243 |
+
query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
|
244 |
+
# [sk, b, np, hn] -> [sk, b * np, hn]
|
245 |
+
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
|
246 |
+
|
247 |
+
# preallocting input tensor: [b * np, sq, sk]
|
248 |
+
matmul_input_buffer = torch.empty(
|
249 |
+
output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
|
250 |
+
device=query_layer.device
|
251 |
+
)
|
252 |
+
|
253 |
+
# Raw attention scores. [b * np, sq, sk]
|
254 |
+
matmul_result = torch.baddbmm(
|
255 |
+
matmul_input_buffer,
|
256 |
+
query_layer.transpose(0, 1), # [b * np, sq, hn]
|
257 |
+
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
|
258 |
+
beta=0.0,
|
259 |
+
alpha=(1.0 / self.norm_factor),
|
260 |
+
)
|
261 |
+
|
262 |
+
# change view to [b, np, sq, sk]
|
263 |
+
attention_scores = matmul_result.view(*output_size)
|
264 |
+
|
265 |
+
# ===========================
|
266 |
+
# Attention probs and dropout
|
267 |
+
# ===========================
|
268 |
+
|
269 |
+
# attention scores and attention mask [b, np, sq, sk]
|
270 |
+
if self.attention_softmax_in_fp32:
|
271 |
+
attention_scores = attention_scores.float()
|
272 |
+
if self.coeff is not None:
|
273 |
+
attention_scores = attention_scores * self.coeff
|
274 |
+
if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
|
275 |
+
attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
|
276 |
+
device=attention_scores.device, dtype=torch.bool)
|
277 |
+
attention_mask.tril_()
|
278 |
+
attention_mask = ~attention_mask
|
279 |
+
if attention_mask is not None:
|
280 |
+
attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
|
281 |
+
attention_probs = F.softmax(attention_scores, dim=-1)
|
282 |
+
attention_probs = attention_probs.type_as(value_layer)
|
283 |
+
|
284 |
+
# This is actually dropping out entire tokens to attend to, which might
|
285 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
286 |
+
attention_probs = self.attention_dropout(attention_probs)
|
287 |
+
# =========================
|
288 |
+
# Context layer. [sq, b, hp]
|
289 |
+
# =========================
|
290 |
+
|
291 |
+
# value_layer -> context layer.
|
292 |
+
# [sk, b, np, hn] --> [b, np, sq, hn]
|
293 |
+
|
294 |
+
# context layer shape: [b, np, sq, hn]
|
295 |
+
output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
|
296 |
+
# change view [sk, b * np, hn]
|
297 |
+
value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
|
298 |
+
# change view [b * np, sq, sk]
|
299 |
+
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
|
300 |
+
# matmul: [b * np, sq, hn]
|
301 |
+
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
|
302 |
+
# change view [b, np, sq, hn]
|
303 |
+
context_layer = context_layer.view(*output_size)
|
304 |
+
# [b, np, sq, hn] --> [sq, b, np, hn]
|
305 |
+
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
|
306 |
+
# [sq, b, np, hn] --> [sq, b, hp]
|
307 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
|
308 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
309 |
+
|
310 |
+
return context_layer
|
311 |
+
|
312 |
+
|
313 |
+
class SelfAttention(torch.nn.Module):
|
314 |
+
"""Parallel self-attention layer abstract class.
|
315 |
+
|
316 |
+
Self-attention layer takes input with size [s, b, h]
|
317 |
+
and returns output of the same size.
|
318 |
+
"""
|
319 |
+
|
320 |
+
def __init__(self, config: ChatGLMConfig, layer_number, device=None):
|
321 |
+
super(SelfAttention, self).__init__()
|
322 |
+
self.layer_number = max(1, layer_number)
|
323 |
+
|
324 |
+
self.projection_size = config.kv_channels * config.num_attention_heads
|
325 |
+
|
326 |
+
# Per attention head and per partition values.
|
327 |
+
self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
|
328 |
+
self.num_attention_heads_per_partition = config.num_attention_heads
|
329 |
+
|
330 |
+
self.multi_query_attention = config.multi_query_attention
|
331 |
+
self.qkv_hidden_size = 3 * self.projection_size
|
332 |
+
if self.multi_query_attention:
|
333 |
+
self.num_multi_query_groups_per_partition = config.multi_query_group_num
|
334 |
+
self.qkv_hidden_size = (
|
335 |
+
self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
|
336 |
+
)
|
337 |
+
self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
|
338 |
+
bias=config.add_bias_linear or config.add_qkv_bias,
|
339 |
+
device=device, **_config_to_kwargs(config)
|
340 |
+
)
|
341 |
+
|
342 |
+
self.core_attention = CoreAttention(config, self.layer_number)
|
343 |
+
|
344 |
+
# Output.
|
345 |
+
self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
|
346 |
+
device=device, **_config_to_kwargs(config)
|
347 |
+
)
|
348 |
+
|
349 |
+
def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
|
350 |
+
if self.multi_query_attention:
|
351 |
+
num_attention_heads = self.num_multi_query_groups_per_partition
|
352 |
+
else:
|
353 |
+
num_attention_heads = self.num_attention_heads_per_partition
|
354 |
+
return torch.empty(
|
355 |
+
inference_max_sequence_len,
|
356 |
+
batch_size,
|
357 |
+
num_attention_heads,
|
358 |
+
self.hidden_size_per_attention_head,
|
359 |
+
dtype=dtype,
|
360 |
+
device=device,
|
361 |
+
)
|
362 |
+
|
363 |
+
def forward(
|
364 |
+
self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
|
365 |
+
):
|
366 |
+
# hidden_states: [sq, b, h]
|
367 |
+
|
368 |
+
# =================================================
|
369 |
+
# Pre-allocate memory for key-values for inference.
|
370 |
+
# =================================================
|
371 |
+
# =====================
|
372 |
+
# Query, Key, and Value
|
373 |
+
# =====================
|
374 |
+
|
375 |
+
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
|
376 |
+
mixed_x_layer = self.query_key_value(hidden_states)
|
377 |
+
|
378 |
+
if self.multi_query_attention:
|
379 |
+
(query_layer, key_layer, value_layer) = mixed_x_layer.split(
|
380 |
+
[
|
381 |
+
self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
|
382 |
+
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
|
383 |
+
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
|
384 |
+
],
|
385 |
+
dim=-1,
|
386 |
+
)
|
387 |
+
query_layer = query_layer.view(
|
388 |
+
query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
|
389 |
+
)
|
390 |
+
key_layer = key_layer.view(
|
391 |
+
key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
|
392 |
+
)
|
393 |
+
value_layer = value_layer.view(
|
394 |
+
value_layer.size()[:-1]
|
395 |
+
+ (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
|
396 |
+
)
|
397 |
+
else:
|
398 |
+
new_tensor_shape = mixed_x_layer.size()[:-1] + \
|
399 |
+
(self.num_attention_heads_per_partition,
|
400 |
+
3 * self.hidden_size_per_attention_head)
|
401 |
+
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
|
402 |
+
|
403 |
+
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
|
404 |
+
(query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
|
405 |
+
|
406 |
+
# apply relative positional encoding (rotary embedding)
|
407 |
+
if rotary_pos_emb is not None:
|
408 |
+
query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
|
409 |
+
key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
|
410 |
+
|
411 |
+
# adjust key and value for inference
|
412 |
+
if kv_cache is not None:
|
413 |
+
cache_k, cache_v = kv_cache
|
414 |
+
key_layer = torch.cat((cache_k, key_layer), dim=0)
|
415 |
+
value_layer = torch.cat((cache_v, value_layer), dim=0)
|
416 |
+
if use_cache:
|
417 |
+
kv_cache = (key_layer, value_layer)
|
418 |
+
else:
|
419 |
+
kv_cache = None
|
420 |
+
|
421 |
+
if self.multi_query_attention:
|
422 |
+
key_layer = key_layer.unsqueeze(-2)
|
423 |
+
key_layer = key_layer.expand(
|
424 |
+
-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
|
425 |
+
)
|
426 |
+
key_layer = key_layer.contiguous().view(
|
427 |
+
key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
|
428 |
+
)
|
429 |
+
value_layer = value_layer.unsqueeze(-2)
|
430 |
+
value_layer = value_layer.expand(
|
431 |
+
-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
|
432 |
+
)
|
433 |
+
value_layer = value_layer.contiguous().view(
|
434 |
+
value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
|
435 |
+
)
|
436 |
+
|
437 |
+
# ==================================
|
438 |
+
# core attention computation
|
439 |
+
# ==================================
|
440 |
+
|
441 |
+
context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
|
442 |
+
|
443 |
+
# =================
|
444 |
+
# Output. [sq, b, h]
|
445 |
+
# =================
|
446 |
+
|
447 |
+
output = self.dense(context_layer)
|
448 |
+
|
449 |
+
return output, kv_cache
|
450 |
+
|
451 |
+
|
452 |
+
def _config_to_kwargs(args):
|
453 |
+
common_kwargs = {
|
454 |
+
"dtype": args.torch_dtype,
|
455 |
+
}
|
456 |
+
return common_kwargs
|
457 |
+
|
458 |
+
|
459 |
+
class MLP(torch.nn.Module):
|
460 |
+
"""MLP.
|
461 |
+
|
462 |
+
MLP will take the input with h hidden state, project it to 4*h
|
463 |
+
hidden dimension, perform nonlinear transformation, and project the
|
464 |
+
state back into h hidden dimension.
|
465 |
+
"""
|
466 |
+
|
467 |
+
def __init__(self, config: ChatGLMConfig, device=None):
|
468 |
+
super(MLP, self).__init__()
|
469 |
+
|
470 |
+
self.add_bias = config.add_bias_linear
|
471 |
+
|
472 |
+
# Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
|
473 |
+
self.dense_h_to_4h = nn.Linear(
|
474 |
+
config.hidden_size,
|
475 |
+
config.ffn_hidden_size * 2,
|
476 |
+
bias=self.add_bias,
|
477 |
+
device=device,
|
478 |
+
**_config_to_kwargs(config)
|
479 |
+
)
|
480 |
+
|
481 |
+
def swiglu(x):
|
482 |
+
x = torch.chunk(x, 2, dim=-1)
|
483 |
+
return F.silu(x[0]) * x[1]
|
484 |
+
|
485 |
+
self.activation_func = swiglu
|
486 |
+
|
487 |
+
# Project back to h.
|
488 |
+
self.dense_4h_to_h = nn.Linear(
|
489 |
+
config.ffn_hidden_size,
|
490 |
+
config.hidden_size,
|
491 |
+
bias=self.add_bias,
|
492 |
+
device=device,
|
493 |
+
**_config_to_kwargs(config)
|
494 |
+
)
|
495 |
+
|
496 |
+
def forward(self, hidden_states):
|
497 |
+
# [s, b, 4hp]
|
498 |
+
intermediate_parallel = self.dense_h_to_4h(hidden_states)
|
499 |
+
intermediate_parallel = self.activation_func(intermediate_parallel)
|
500 |
+
# [s, b, h]
|
501 |
+
output = self.dense_4h_to_h(intermediate_parallel)
|
502 |
+
return output
|
503 |
+
|
504 |
+
|
505 |
+
class GLMBlock(torch.nn.Module):
|
506 |
+
"""A single transformer layer.
|
507 |
+
|
508 |
+
Transformer layer takes input with size [s, b, h] and returns an
|
509 |
+
output of the same size.
|
510 |
+
"""
|
511 |
+
|
512 |
+
def __init__(self, config: ChatGLMConfig, layer_number, device=None):
|
513 |
+
super(GLMBlock, self).__init__()
|
514 |
+
self.layer_number = layer_number
|
515 |
+
|
516 |
+
self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
|
517 |
+
|
518 |
+
self.fp32_residual_connection = config.fp32_residual_connection
|
519 |
+
|
520 |
+
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
|
521 |
+
# Layernorm on the input data.
|
522 |
+
self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
|
523 |
+
dtype=config.torch_dtype)
|
524 |
+
|
525 |
+
# Self attention.
|
526 |
+
self.self_attention = SelfAttention(config, layer_number, device=device)
|
527 |
+
self.hidden_dropout = config.hidden_dropout
|
528 |
+
|
529 |
+
# Layernorm on the attention output
|
530 |
+
self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
|
531 |
+
dtype=config.torch_dtype)
|
532 |
+
|
533 |
+
# MLP
|
534 |
+
self.mlp = MLP(config, device=device)
|
535 |
+
|
536 |
+
def forward(
|
537 |
+
self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
|
538 |
+
):
|
539 |
+
# hidden_states: [s, b, h]
|
540 |
+
|
541 |
+
# Layer norm at the beginning of the transformer layer.
|
542 |
+
layernorm_output = self.input_layernorm(hidden_states)
|
543 |
+
# Self attention.
|
544 |
+
attention_output, kv_cache = self.self_attention(
|
545 |
+
layernorm_output,
|
546 |
+
attention_mask,
|
547 |
+
rotary_pos_emb,
|
548 |
+
kv_cache=kv_cache,
|
549 |
+
use_cache=use_cache
|
550 |
+
)
|
551 |
+
|
552 |
+
# Residual connection.
|
553 |
+
if self.apply_residual_connection_post_layernorm:
|
554 |
+
residual = layernorm_output
|
555 |
+
else:
|
556 |
+
residual = hidden_states
|
557 |
+
|
558 |
+
layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
|
559 |
+
layernorm_input = residual + layernorm_input
|
560 |
+
|
561 |
+
# Layer norm post the self attention.
|
562 |
+
layernorm_output = self.post_attention_layernorm(layernorm_input)
|
563 |
+
|
564 |
+
# MLP.
|
565 |
+
mlp_output = self.mlp(layernorm_output)
|
566 |
+
|
567 |
+
# Second residual connection.
|
568 |
+
if self.apply_residual_connection_post_layernorm:
|
569 |
+
residual = layernorm_output
|
570 |
+
else:
|
571 |
+
residual = layernorm_input
|
572 |
+
|
573 |
+
output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
|
574 |
+
output = residual + output
|
575 |
+
|
576 |
+
return output, kv_cache
|
577 |
+
|
578 |
+
|
579 |
+
class GLMTransformer(torch.nn.Module):
|
580 |
+
"""Transformer class."""
|
581 |
+
|
582 |
+
def __init__(self, config: ChatGLMConfig, device=None):
|
583 |
+
super(GLMTransformer, self).__init__()
|
584 |
+
|
585 |
+
self.fp32_residual_connection = config.fp32_residual_connection
|
586 |
+
self.post_layer_norm = config.post_layer_norm
|
587 |
+
|
588 |
+
# Number of layers.
|
589 |
+
self.num_layers = config.num_layers
|
590 |
+
|
591 |
+
# Transformer layers.
|
592 |
+
def build_layer(layer_number):
|
593 |
+
return GLMBlock(config, layer_number, device=device)
|
594 |
+
|
595 |
+
self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
|
596 |
+
|
597 |
+
if self.post_layer_norm:
|
598 |
+
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
|
599 |
+
# Final layer norm before output.
|
600 |
+
self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
|
601 |
+
dtype=config.torch_dtype)
|
602 |
+
|
603 |
+
self.gradient_checkpointing = False
|
604 |
+
|
605 |
+
def _get_layer(self, layer_number):
|
606 |
+
return self.layers[layer_number]
|
607 |
+
|
608 |
+
def forward(
|
609 |
+
self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,
|
610 |
+
use_cache: Optional[bool] = True,
|
611 |
+
output_hidden_states: Optional[bool] = False,
|
612 |
+
):
|
613 |
+
if not kv_caches:
|
614 |
+
kv_caches = [None for _ in range(self.num_layers)]
|
615 |
+
presents = () if use_cache else None
|
616 |
+
if self.gradient_checkpointing and self.training:
|
617 |
+
if use_cache:
|
618 |
+
logger.warning_once(
|
619 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
620 |
+
)
|
621 |
+
use_cache = False
|
622 |
+
|
623 |
+
all_self_attentions = None
|
624 |
+
all_hidden_states = () if output_hidden_states else None
|
625 |
+
for index in range(self.num_layers):
|
626 |
+
if output_hidden_states:
|
627 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
628 |
+
|
629 |
+
layer = self._get_layer(index)
|
630 |
+
if self.gradient_checkpointing and self.training:
|
631 |
+
layer_ret = torch.utils.checkpoint.checkpoint(
|
632 |
+
layer,
|
633 |
+
hidden_states,
|
634 |
+
attention_mask,
|
635 |
+
rotary_pos_emb,
|
636 |
+
kv_caches[index],
|
637 |
+
use_cache
|
638 |
+
)
|
639 |
+
else:
|
640 |
+
layer_ret = layer(
|
641 |
+
hidden_states,
|
642 |
+
attention_mask,
|
643 |
+
rotary_pos_emb,
|
644 |
+
kv_cache=kv_caches[index],
|
645 |
+
use_cache=use_cache
|
646 |
+
)
|
647 |
+
hidden_states, kv_cache = layer_ret
|
648 |
+
if use_cache:
|
649 |
+
presents = presents + (kv_cache,)
|
650 |
+
|
651 |
+
if output_hidden_states:
|
652 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
653 |
+
|
654 |
+
# Final layer norm.
|
655 |
+
if self.post_layer_norm:
|
656 |
+
hidden_states = self.final_layernorm(hidden_states)
|
657 |
+
|
658 |
+
return hidden_states, presents, all_hidden_states, all_self_attentions
|
659 |
+
|
660 |
+
|
661 |
+
class ChatGLMPreTrainedModel(PreTrainedModel):
|
662 |
+
"""
|
663 |
+
An abstract class to handle weights initialization and
|
664 |
+
a simple interface for downloading and loading pretrained models.
|
665 |
+
"""
|
666 |
+
|
667 |
+
is_parallelizable = False
|
668 |
+
supports_gradient_checkpointing = True
|
669 |
+
config_class = ChatGLMConfig
|
670 |
+
base_model_prefix = "transformer"
|
671 |
+
_no_split_modules = ["GLMBlock"]
|
672 |
+
|
673 |
+
def _init_weights(self, module: nn.Module):
|
674 |
+
"""Initialize the weights."""
|
675 |
+
return
|
676 |
+
|
677 |
+
def get_masks(self, input_ids, past_key_values, padding_mask=None):
|
678 |
+
batch_size, seq_length = input_ids.shape
|
679 |
+
full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
|
680 |
+
full_attention_mask.tril_()
|
681 |
+
past_length = 0
|
682 |
+
if past_key_values:
|
683 |
+
past_length = past_key_values[0][0].shape[0]
|
684 |
+
if past_length:
|
685 |
+
full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
|
686 |
+
device=input_ids.device), full_attention_mask), dim=-1)
|
687 |
+
if padding_mask is not None:
|
688 |
+
full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
|
689 |
+
if not past_length and padding_mask is not None:
|
690 |
+
full_attention_mask -= padding_mask.unsqueeze(-1) - 1
|
691 |
+
full_attention_mask = (full_attention_mask < 0.5).bool()
|
692 |
+
full_attention_mask.unsqueeze_(1)
|
693 |
+
return full_attention_mask
|
694 |
+
|
695 |
+
def get_position_ids(self, input_ids, device):
|
696 |
+
batch_size, seq_length = input_ids.shape
|
697 |
+
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
|
698 |
+
return position_ids
|
699 |
+
|
700 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
701 |
+
if isinstance(module, GLMTransformer):
|
702 |
+
module.gradient_checkpointing = value
|
703 |
+
|
704 |
+
|
705 |
+
class Embedding(torch.nn.Module):
|
706 |
+
"""Language model embeddings."""
|
707 |
+
|
708 |
+
def __init__(self, config: ChatGLMConfig, device=None):
|
709 |
+
super(Embedding, self).__init__()
|
710 |
+
|
711 |
+
self.hidden_size = config.hidden_size
|
712 |
+
# Word embeddings (parallel).
|
713 |
+
self.word_embeddings = nn.Embedding(
|
714 |
+
config.padded_vocab_size,
|
715 |
+
self.hidden_size,
|
716 |
+
dtype=config.torch_dtype,
|
717 |
+
device=device
|
718 |
+
)
|
719 |
+
self.fp32_residual_connection = config.fp32_residual_connection
|
720 |
+
|
721 |
+
def forward(self, input_ids):
|
722 |
+
# Embeddings.
|
723 |
+
words_embeddings = self.word_embeddings(input_ids)
|
724 |
+
embeddings = words_embeddings
|
725 |
+
# Data format change to avoid explicit tranposes : [b s h] --> [s b h].
|
726 |
+
embeddings = embeddings.transpose(0, 1).contiguous()
|
727 |
+
# If the input flag for fp32 residual connection is set, convert for float.
|
728 |
+
if self.fp32_residual_connection:
|
729 |
+
embeddings = embeddings.float()
|
730 |
+
return embeddings
|
731 |
+
|
732 |
+
|
733 |
+
class ChatGLMModel(ChatGLMPreTrainedModel):
|
734 |
+
def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
|
735 |
+
super().__init__(config)
|
736 |
+
if empty_init:
|
737 |
+
init_method = skip_init
|
738 |
+
else:
|
739 |
+
init_method = default_init
|
740 |
+
init_kwargs = {}
|
741 |
+
if device is not None:
|
742 |
+
init_kwargs["device"] = device
|
743 |
+
self.embedding = init_method(Embedding, config, **init_kwargs)
|
744 |
+
self.num_layers = config.num_layers
|
745 |
+
self.multi_query_group_num = config.multi_query_group_num
|
746 |
+
self.kv_channels = config.kv_channels
|
747 |
+
|
748 |
+
# Rotary positional embeddings
|
749 |
+
self.seq_length = config.seq_length
|
750 |
+
rotary_dim = (
|
751 |
+
config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
|
752 |
+
)
|
753 |
+
|
754 |
+
self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
|
755 |
+
dtype=config.torch_dtype)
|
756 |
+
self.encoder = init_method(GLMTransformer, config, **init_kwargs)
|
757 |
+
self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
|
758 |
+
dtype=config.torch_dtype, **init_kwargs)
|
759 |
+
self.pre_seq_len = config.pre_seq_len
|
760 |
+
self.prefix_projection = config.prefix_projection
|
761 |
+
if self.pre_seq_len is not None:
|
762 |
+
for param in self.parameters():
|
763 |
+
param.requires_grad = False
|
764 |
+
self.prefix_tokens = torch.arange(self.pre_seq_len).long()
|
765 |
+
self.prefix_encoder = PrefixEncoder(config)
|
766 |
+
self.dropout = torch.nn.Dropout(0.1)
|
767 |
+
|
768 |
+
def get_input_embeddings(self):
|
769 |
+
return self.embedding.word_embeddings
|
770 |
+
|
771 |
+
def get_prompt(self, batch_size, device, dtype=torch.half):
|
772 |
+
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
|
773 |
+
past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
|
774 |
+
past_key_values = past_key_values.view(
|
775 |
+
batch_size,
|
776 |
+
self.pre_seq_len,
|
777 |
+
self.num_layers * 2,
|
778 |
+
self.multi_query_group_num,
|
779 |
+
self.kv_channels
|
780 |
+
)
|
781 |
+
# seq_len, b, nh, hidden_size
|
782 |
+
past_key_values = self.dropout(past_key_values)
|
783 |
+
past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
|
784 |
+
return past_key_values
|
785 |
+
|
786 |
+
def forward(
|
787 |
+
self,
|
788 |
+
input_ids,
|
789 |
+
position_ids: Optional[torch.Tensor] = None,
|
790 |
+
attention_mask: Optional[torch.BoolTensor] = None,
|
791 |
+
full_attention_mask: Optional[torch.BoolTensor] = None,
|
792 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
793 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
794 |
+
use_cache: Optional[bool] = None,
|
795 |
+
output_hidden_states: Optional[bool] = None,
|
796 |
+
return_dict: Optional[bool] = None,
|
797 |
+
):
|
798 |
+
output_hidden_states = (
|
799 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
800 |
+
)
|
801 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
802 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
803 |
+
|
804 |
+
batch_size, seq_length = input_ids.shape
|
805 |
+
|
806 |
+
if inputs_embeds is None:
|
807 |
+
inputs_embeds = self.embedding(input_ids)
|
808 |
+
|
809 |
+
if self.pre_seq_len is not None:
|
810 |
+
if past_key_values is None:
|
811 |
+
past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
|
812 |
+
dtype=inputs_embeds.dtype)
|
813 |
+
if attention_mask is not None:
|
814 |
+
attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)),
|
815 |
+
attention_mask], dim=-1)
|
816 |
+
|
817 |
+
if full_attention_mask is None:
|
818 |
+
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
|
819 |
+
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
|
820 |
+
|
821 |
+
# Rotary positional embeddings
|
822 |
+
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
|
823 |
+
if position_ids is not None:
|
824 |
+
rotary_pos_emb = rotary_pos_emb[position_ids]
|
825 |
+
else:
|
826 |
+
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
|
827 |
+
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
|
828 |
+
|
829 |
+
# Run encoder.
|
830 |
+
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
|
831 |
+
inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
|
832 |
+
kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
|
833 |
+
)
|
834 |
+
|
835 |
+
if not return_dict:
|
836 |
+
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
837 |
+
|
838 |
+
return BaseModelOutputWithPast(
|
839 |
+
last_hidden_state=hidden_states,
|
840 |
+
past_key_values=presents,
|
841 |
+
hidden_states=all_hidden_states,
|
842 |
+
attentions=all_self_attentions,
|
843 |
+
)
|
844 |
+
|
845 |
+
def quantize(self, weight_bit_width: int):
|
846 |
+
from .quantization import quantize
|
847 |
+
quantize(self.encoder, weight_bit_width)
|
848 |
+
return self
|
849 |
+
|
850 |
+
|
851 |
+
class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
852 |
+
def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
|
853 |
+
super().__init__(config)
|
854 |
+
|
855 |
+
self.max_sequence_length = config.max_length
|
856 |
+
self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
|
857 |
+
self.config = config
|
858 |
+
self.quantized = False
|
859 |
+
|
860 |
+
if self.config.quantization_bit:
|
861 |
+
self.quantize(self.config.quantization_bit, empty_init=True)
|
862 |
+
|
863 |
+
def _update_model_kwargs_for_generation(
|
864 |
+
self,
|
865 |
+
outputs: ModelOutput,
|
866 |
+
model_kwargs: Dict[str, Any],
|
867 |
+
is_encoder_decoder: bool = False,
|
868 |
+
standardize_cache_format: bool = False,
|
869 |
+
) -> Dict[str, Any]:
|
870 |
+
# update past_key_values
|
871 |
+
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
|
872 |
+
outputs, standardize_cache_format=standardize_cache_format
|
873 |
+
)
|
874 |
+
|
875 |
+
# update attention mask
|
876 |
+
if "attention_mask" in model_kwargs:
|
877 |
+
attention_mask = model_kwargs["attention_mask"]
|
878 |
+
model_kwargs["attention_mask"] = torch.cat(
|
879 |
+
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
|
880 |
+
)
|
881 |
+
|
882 |
+
# update position ids
|
883 |
+
if "position_ids" in model_kwargs:
|
884 |
+
position_ids = model_kwargs["position_ids"]
|
885 |
+
new_position_id = position_ids[..., -1:].clone()
|
886 |
+
new_position_id += 1
|
887 |
+
model_kwargs["position_ids"] = torch.cat(
|
888 |
+
[position_ids, new_position_id], dim=-1
|
889 |
+
)
|
890 |
+
|
891 |
+
model_kwargs["is_first_forward"] = False
|
892 |
+
return model_kwargs
|
893 |
+
|
894 |
+
def prepare_inputs_for_generation(
|
895 |
+
self,
|
896 |
+
input_ids: torch.LongTensor,
|
897 |
+
past_key_values: Optional[torch.Tensor] = None,
|
898 |
+
attention_mask: Optional[torch.Tensor] = None,
|
899 |
+
position_ids: Optional[torch.Tensor] = None,
|
900 |
+
use_cache: Optional[bool] = None,
|
901 |
+
is_first_forward: bool = True,
|
902 |
+
**kwargs
|
903 |
+
) -> dict:
|
904 |
+
# only last token for input_ids if past is not None
|
905 |
+
if position_ids is None:
|
906 |
+
position_ids = self.get_position_ids(input_ids, device=input_ids.device)
|
907 |
+
if not is_first_forward:
|
908 |
+
if past_key_values is not None:
|
909 |
+
position_ids = position_ids[..., -1:]
|
910 |
+
input_ids = input_ids[:, -1:]
|
911 |
+
return {
|
912 |
+
"input_ids": input_ids,
|
913 |
+
"past_key_values": past_key_values,
|
914 |
+
"position_ids": position_ids,
|
915 |
+
"attention_mask": attention_mask,
|
916 |
+
"return_last_logit": True,
|
917 |
+
"use_cache": use_cache
|
918 |
+
}
|
919 |
+
|
920 |
+
def forward(
|
921 |
+
self,
|
922 |
+
input_ids: Optional[torch.Tensor] = None,
|
923 |
+
position_ids: Optional[torch.Tensor] = None,
|
924 |
+
attention_mask: Optional[torch.Tensor] = None,
|
925 |
+
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
|
926 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
927 |
+
labels: Optional[torch.Tensor] = None,
|
928 |
+
use_cache: Optional[bool] = None,
|
929 |
+
output_attentions: Optional[bool] = None,
|
930 |
+
output_hidden_states: Optional[bool] = None,
|
931 |
+
return_dict: Optional[bool] = None,
|
932 |
+
return_last_logit: Optional[bool] = False,
|
933 |
+
):
|
934 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
935 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
936 |
+
|
937 |
+
transformer_outputs = self.transformer(
|
938 |
+
input_ids=input_ids,
|
939 |
+
position_ids=position_ids,
|
940 |
+
attention_mask=attention_mask,
|
941 |
+
past_key_values=past_key_values,
|
942 |
+
inputs_embeds=inputs_embeds,
|
943 |
+
use_cache=use_cache,
|
944 |
+
output_hidden_states=output_hidden_states,
|
945 |
+
return_dict=return_dict,
|
946 |
+
)
|
947 |
+
|
948 |
+
hidden_states = transformer_outputs[0]
|
949 |
+
if return_last_logit:
|
950 |
+
hidden_states = hidden_states[-1:]
|
951 |
+
lm_logits = self.transformer.output_layer(hidden_states)
|
952 |
+
lm_logits = lm_logits.transpose(0, 1).contiguous()
|
953 |
+
|
954 |
+
loss = None
|
955 |
+
if labels is not None:
|
956 |
+
lm_logits = lm_logits.to(torch.float32)
|
957 |
+
|
958 |
+
# Shift so that tokens < n predict n
|
959 |
+
shift_logits = lm_logits[..., :-1, :].contiguous()
|
960 |
+
shift_labels = labels[..., 1:].contiguous()
|
961 |
+
# Flatten the tokens
|
962 |
+
loss_fct = CrossEntropyLoss(ignore_index=-100)
|
963 |
+
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
964 |
+
|
965 |
+
lm_logits = lm_logits.to(hidden_states.dtype)
|
966 |
+
loss = loss.to(hidden_states.dtype)
|
967 |
+
|
968 |
+
if not return_dict:
|
969 |
+
output = (lm_logits,) + transformer_outputs[1:]
|
970 |
+
return ((loss,) + output) if loss is not None else output
|
971 |
+
|
972 |
+
return CausalLMOutputWithPast(
|
973 |
+
loss=loss,
|
974 |
+
logits=lm_logits,
|
975 |
+
past_key_values=transformer_outputs.past_key_values,
|
976 |
+
hidden_states=transformer_outputs.hidden_states,
|
977 |
+
attentions=transformer_outputs.attentions,
|
978 |
+
)
|
979 |
+
|
980 |
+
@staticmethod
|
981 |
+
def _reorder_cache(
|
982 |
+
past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
|
983 |
+
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
|
984 |
+
"""
|
985 |
+
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
|
986 |
+
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
|
987 |
+
beam_idx at every generation step.
|
988 |
+
|
989 |
+
Output shares the same memory storage as `past`.
|
990 |
+
"""
|
991 |
+
return tuple(
|
992 |
+
(
|
993 |
+
layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
|
994 |
+
layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
|
995 |
+
)
|
996 |
+
for layer_past in past
|
997 |
+
)
|
998 |
+
|
999 |
+
def process_response(self, output, history):
|
1000 |
+
content = ""
|
1001 |
+
history = deepcopy(history)
|
1002 |
+
for response in output.split("<|assistant|>"):
|
1003 |
+
metadata, content = response.split("\n", maxsplit=1)
|
1004 |
+
if not metadata.strip():
|
1005 |
+
content = content.strip()
|
1006 |
+
history.append({"role": "assistant", "metadata": metadata, "content": content})
|
1007 |
+
content = content.replace("[[训练时间]]", "2023年")
|
1008 |
+
else:
|
1009 |
+
history.append({"role": "assistant", "metadata": metadata, "content": content})
|
1010 |
+
if history[0]["role"] == "system" and "tools" in history[0]:
|
1011 |
+
content = "\n".join(content.split("\n")[1:-1])
|
1012 |
+
def tool_call(**kwargs):
|
1013 |
+
return kwargs
|
1014 |
+
parameters = eval(content)
|
1015 |
+
content = {"name": metadata.strip(), "parameters": parameters}
|
1016 |
+
else:
|
1017 |
+
content = {"name": metadata.strip(), "content": content}
|
1018 |
+
return content, history
|
1019 |
+
|
1020 |
+
@torch.inference_mode()
|
1021 |
+
def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
|
1022 |
+
max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
|
1023 |
+
**kwargs):
|
1024 |
+
if history is None:
|
1025 |
+
history = []
|
1026 |
+
if logits_processor is None:
|
1027 |
+
logits_processor = LogitsProcessorList()
|
1028 |
+
logits_processor.append(InvalidScoreLogitsProcessor())
|
1029 |
+
gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
|
1030 |
+
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
|
1031 |
+
inputs = tokenizer.build_chat_input(query, history=history, role=role)
|
1032 |
+
inputs = inputs.to(self.device)
|
1033 |
+
eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
|
1034 |
+
tokenizer.get_command("<|observation|>")]
|
1035 |
+
outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
|
1036 |
+
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
|
1037 |
+
response = tokenizer.decode(outputs)
|
1038 |
+
history.append({"role": role, "content": query})
|
1039 |
+
response, history = self.process_response(response, history)
|
1040 |
+
return response, history
|
1041 |
+
|
1042 |
+
@torch.inference_mode()
|
1043 |
+
def stream_chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
|
1044 |
+
past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
|
1045 |
+
logits_processor=None, return_past_key_values=False, **kwargs):
|
1046 |
+
if history is None:
|
1047 |
+
history = []
|
1048 |
+
if logits_processor is None:
|
1049 |
+
logits_processor = LogitsProcessorList()
|
1050 |
+
logits_processor.append(InvalidScoreLogitsProcessor())
|
1051 |
+
eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
|
1052 |
+
tokenizer.get_command("<|observation|>")]
|
1053 |
+
gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
|
1054 |
+
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
|
1055 |
+
if past_key_values is None:
|
1056 |
+
inputs = tokenizer.build_chat_input(query, history=history, role=role)
|
1057 |
+
else:
|
1058 |
+
inputs = tokenizer.build_chat_input(query, role=role)
|
1059 |
+
inputs = inputs.to(self.device)
|
1060 |
+
if past_key_values is not None:
|
1061 |
+
past_length = past_key_values[0][0].shape[0]
|
1062 |
+
if self.transformer.pre_seq_len is not None:
|
1063 |
+
past_length -= self.transformer.pre_seq_len
|
1064 |
+
inputs.position_ids += past_length
|
1065 |
+
attention_mask = inputs.attention_mask
|
1066 |
+
attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
|
1067 |
+
inputs['attention_mask'] = attention_mask
|
1068 |
+
history.append({"role": role, "content": query})
|
1069 |
+
for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
|
1070 |
+
eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
|
1071 |
+
**gen_kwargs):
|
1072 |
+
if return_past_key_values:
|
1073 |
+
outputs, past_key_values = outputs
|
1074 |
+
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
|
1075 |
+
response = tokenizer.decode(outputs)
|
1076 |
+
if response and response[-1] != "�":
|
1077 |
+
response, new_history = self.process_response(response, history)
|
1078 |
+
if return_past_key_values:
|
1079 |
+
yield response, new_history, past_key_values
|
1080 |
+
else:
|
1081 |
+
yield response, new_history
|
1082 |
+
|
1083 |
+
@torch.inference_mode()
|
1084 |
+
def stream_generate(
|
1085 |
+
self,
|
1086 |
+
input_ids,
|
1087 |
+
generation_config: Optional[GenerationConfig] = None,
|
1088 |
+
logits_processor: Optional[LogitsProcessorList] = None,
|
1089 |
+
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
1090 |
+
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
1091 |
+
return_past_key_values=False,
|
1092 |
+
**kwargs,
|
1093 |
+
):
|
1094 |
+
batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
|
1095 |
+
|
1096 |
+
if generation_config is None:
|
1097 |
+
generation_config = self.generation_config
|
1098 |
+
generation_config = copy.deepcopy(generation_config)
|
1099 |
+
model_kwargs = generation_config.update(**kwargs)
|
1100 |
+
model_kwargs["use_cache"] = generation_config.use_cache
|
1101 |
+
bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
|
1102 |
+
|
1103 |
+
if isinstance(eos_token_id, int):
|
1104 |
+
eos_token_id = [eos_token_id]
|
1105 |
+
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
|
1106 |
+
|
1107 |
+
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
1108 |
+
if has_default_max_length and generation_config.max_new_tokens is None:
|
1109 |
+
warnings.warn(
|
1110 |
+
f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
|
1111 |
+
"This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
|
1112 |
+
" recommend using `max_new_tokens` to control the maximum length of the generation.",
|
1113 |
+
UserWarning,
|
1114 |
+
)
|
1115 |
+
elif generation_config.max_new_tokens is not None:
|
1116 |
+
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
1117 |
+
if not has_default_max_length:
|
1118 |
+
logger.warn(
|
1119 |
+
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
1120 |
+
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
|
1121 |
+
"Please refer to the documentation for more information. "
|
1122 |
+
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
|
1123 |
+
UserWarning,
|
1124 |
+
)
|
1125 |
+
|
1126 |
+
if input_ids_seq_length >= generation_config.max_length:
|
1127 |
+
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
|
1128 |
+
logger.warning(
|
1129 |
+
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
|
1130 |
+
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
|
1131 |
+
" increasing `max_new_tokens`."
|
1132 |
+
)
|
1133 |
+
|
1134 |
+
# 2. Set generation parameters if not already defined
|
1135 |
+
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
1136 |
+
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
1137 |
+
|
1138 |
+
logits_processor = self._get_logits_processor(
|
1139 |
+
generation_config=generation_config,
|
1140 |
+
input_ids_seq_length=input_ids_seq_length,
|
1141 |
+
encoder_input_ids=input_ids,
|
1142 |
+
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
1143 |
+
logits_processor=logits_processor,
|
1144 |
+
)
|
1145 |
+
|
1146 |
+
stopping_criteria = self._get_stopping_criteria(
|
1147 |
+
generation_config=generation_config, stopping_criteria=stopping_criteria
|
1148 |
+
)
|
1149 |
+
logits_warper = self._get_logits_warper(generation_config)
|
1150 |
+
|
1151 |
+
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
1152 |
+
scores = None
|
1153 |
+
while True:
|
1154 |
+
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
1155 |
+
# forward pass to get next token
|
1156 |
+
outputs = self(
|
1157 |
+
**model_inputs,
|
1158 |
+
return_dict=True,
|
1159 |
+
output_attentions=False,
|
1160 |
+
output_hidden_states=False,
|
1161 |
+
)
|
1162 |
+
|
1163 |
+
next_token_logits = outputs.logits[:, -1, :]
|
1164 |
+
|
1165 |
+
# pre-process distribution
|
1166 |
+
next_token_scores = logits_processor(input_ids, next_token_logits)
|
1167 |
+
next_token_scores = logits_warper(input_ids, next_token_scores)
|
1168 |
+
|
1169 |
+
# sample
|
1170 |
+
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
1171 |
+
if generation_config.do_sample:
|
1172 |
+
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
1173 |
+
else:
|
1174 |
+
next_tokens = torch.argmax(probs, dim=-1)
|
1175 |
+
# update generated ids, model inputs, and length for next step
|
1176 |
+
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
1177 |
+
model_kwargs = self._update_model_kwargs_for_generation(
|
1178 |
+
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
1179 |
+
)
|
1180 |
+
unfinished_sequences = unfinished_sequences.mul(
|
1181 |
+
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
1182 |
+
)
|
1183 |
+
if return_past_key_values:
|
1184 |
+
yield input_ids, outputs.past_key_values
|
1185 |
+
else:
|
1186 |
+
yield input_ids
|
1187 |
+
# stop when each sentence is finished, or if we exceed the maximum length
|
1188 |
+
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
1189 |
+
break
|
1190 |
+
|
1191 |
+
def quantize(self, bits: int, empty_init=False, device=None, **kwargs):
|
1192 |
+
if bits == 0:
|
1193 |
+
return
|
1194 |
+
|
1195 |
+
from .quantization import quantize
|
1196 |
+
|
1197 |
+
if self.quantized:
|
1198 |
+
logger.info("Already quantized.")
|
1199 |
+
return self
|
1200 |
+
|
1201 |
+
self.quantized = True
|
1202 |
+
|
1203 |
+
self.config.quantization_bit = bits
|
1204 |
+
|
1205 |
+
self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device,
|
1206 |
+
**kwargs)
|
1207 |
+
return self
|
1208 |
+
|
1209 |
+
|
1210 |
+
class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
|
1211 |
+
def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
|
1212 |
+
super().__init__(config)
|
1213 |
+
|
1214 |
+
self.num_labels = config.num_labels
|
1215 |
+
self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
|
1216 |
+
|
1217 |
+
self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half)
|
1218 |
+
if config.classifier_dropout is not None:
|
1219 |
+
self.dropout = nn.Dropout(config.classifier_dropout)
|
1220 |
+
else:
|
1221 |
+
self.dropout = None
|
1222 |
+
self.config = config
|
1223 |
+
|
1224 |
+
if self.config.quantization_bit:
|
1225 |
+
self.quantize(self.config.quantization_bit, empty_init=True)
|
1226 |
+
|
1227 |
+
def forward(
|
1228 |
+
self,
|
1229 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1230 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1231 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1232 |
+
full_attention_mask: Optional[torch.Tensor] = None,
|
1233 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
1234 |
+
inputs_embeds: Optional[torch.LongTensor] = None,
|
1235 |
+
labels: Optional[torch.LongTensor] = None,
|
1236 |
+
use_cache: Optional[bool] = None,
|
1237 |
+
output_hidden_states: Optional[bool] = None,
|
1238 |
+
return_dict: Optional[bool] = None,
|
1239 |
+
) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
|
1240 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1241 |
+
|
1242 |
+
transformer_outputs = self.transformer(
|
1243 |
+
input_ids=input_ids,
|
1244 |
+
position_ids=position_ids,
|
1245 |
+
attention_mask=attention_mask,
|
1246 |
+
full_attention_mask=full_attention_mask,
|
1247 |
+
past_key_values=past_key_values,
|
1248 |
+
inputs_embeds=inputs_embeds,
|
1249 |
+
use_cache=use_cache,
|
1250 |
+
output_hidden_states=output_hidden_states,
|
1251 |
+
return_dict=return_dict,
|
1252 |
+
)
|
1253 |
+
|
1254 |
+
hidden_states = transformer_outputs[0]
|
1255 |
+
pooled_hidden_states = hidden_states[-1]
|
1256 |
+
if self.dropout is not None:
|
1257 |
+
pooled_hidden_states = self.dropout(pooled_hidden_states)
|
1258 |
+
logits = self.classifier_head(pooled_hidden_states)
|
1259 |
+
|
1260 |
+
loss = None
|
1261 |
+
if labels is not None:
|
1262 |
+
if self.config.problem_type is None:
|
1263 |
+
if self.num_labels == 1:
|
1264 |
+
self.config.problem_type = "regression"
|
1265 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
1266 |
+
self.config.problem_type = "single_label_classification"
|
1267 |
+
else:
|
1268 |
+
self.config.problem_type = "multi_label_classification"
|
1269 |
+
|
1270 |
+
if self.config.problem_type == "regression":
|
1271 |
+
loss_fct = MSELoss()
|
1272 |
+
if self.num_labels == 1:
|
1273 |
+
loss = loss_fct(logits.squeeze().float(), labels.squeeze())
|
1274 |
+
else:
|
1275 |
+
loss = loss_fct(logits.float(), labels)
|
1276 |
+
elif self.config.problem_type == "single_label_classification":
|
1277 |
+
loss_fct = CrossEntropyLoss()
|
1278 |
+
loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1))
|
1279 |
+
elif self.config.problem_type == "multi_label_classification":
|
1280 |
+
loss_fct = BCEWithLogitsLoss()
|
1281 |
+
loss = loss_fct(logits.float(), labels.view(-1, self.num_labels))
|
1282 |
+
|
1283 |
+
if not return_dict:
|
1284 |
+
output = (logits,) + transformer_outputs[1:]
|
1285 |
+
return ((loss,) + output) if loss is not None else output
|
1286 |
+
|
1287 |
+
return SequenceClassifierOutputWithPast(
|
1288 |
+
loss=loss,
|
1289 |
+
logits=logits,
|
1290 |
+
past_key_values=transformer_outputs.past_key_values,
|
1291 |
+
hidden_states=transformer_outputs.hidden_states,
|
1292 |
+
attentions=transformer_outputs.attentions,
|
1293 |
+
)
|
chatglm3-6b/pytorch_model-00001-of-00007.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4d5567466e89625dbd10e51c69a02982f233a10108cf232a379defdbb065ae0b
|
3 |
+
size 1827781090
|
chatglm3-6b/pytorch_model-00002-of-00007.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4ad41534016ac4a2431fa2d4b08efbe28f963744135ec52b2ea13cc86730fa2a
|
3 |
+
size 1968299480
|
chatglm3-6b/pytorch_model-00003-of-00007.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a2be9b17c332a8345e787953d4879caee4747ad4b263f013aa3c4654c94c3d24
|
3 |
+
size 1927415036
|
chatglm3-6b/pytorch_model-00004-of-00007.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b5526891e1b4c8edd2e3688df5156aa107e2140fe7e3f9d0d54f9cbe3b6ee3b5
|
3 |
+
size 1815225998
|
chatglm3-6b/pytorch_model-00005-of-00007.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:84bb18a476f74beaf4d941733bd1c475791eba799b228f78d0165de989cb7a40
|
3 |
+
size 1968299544
|
chatglm3-6b/pytorch_model-00006-of-00007.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1181875a2dc30fba968d72d0fc4628b9a60d3866bf680eb14b9822b5b504830f
|
3 |
+
size 1927415036
|
chatglm3-6b/pytorch_model-00007-of-00007.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1271b638dc0a88206d1c7a51bcaa862410eb56f3e59fd0068a96e96cb5e3f4f0
|
3 |
+
size 1052808542
|
chatglm3-6b/pytorch_model.bin.index.json
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"metadata": {
|
3 |
+
"total_size": 12487168064
|
4 |
+
},
|
5 |
+
"weight_map": {
|
6 |
+
"transformer.embedding.word_embeddings.weight": "pytorch_model-00001-of-00007.bin",
|
7 |
+
"transformer.encoder.final_layernorm.weight": "pytorch_model-00007-of-00007.bin",
|
8 |
+
"transformer.encoder.layers.0.input_layernorm.weight": "pytorch_model-00001-of-00007.bin",
|
9 |
+
"transformer.encoder.layers.0.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00007.bin",
|
10 |
+
"transformer.encoder.layers.0.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00007.bin",
|
11 |
+
"transformer.encoder.layers.0.post_attention_layernorm.weight": "pytorch_model-00001-of-00007.bin",
|
12 |
+
"transformer.encoder.layers.0.self_attention.dense.weight": "pytorch_model-00001-of-00007.bin",
|
13 |
+
"transformer.encoder.layers.0.self_attention.query_key_value.bias": "pytorch_model-00001-of-00007.bin",
|
14 |
+
"transformer.encoder.layers.0.self_attention.query_key_value.weight": "pytorch_model-00001-of-00007.bin",
|
15 |
+
"transformer.encoder.layers.1.input_layernorm.weight": "pytorch_model-00001-of-00007.bin",
|
16 |
+
"transformer.encoder.layers.1.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00007.bin",
|
17 |
+
"transformer.encoder.layers.1.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00007.bin",
|
18 |
+
"transformer.encoder.layers.1.post_attention_layernorm.weight": "pytorch_model-00001-of-00007.bin",
|
19 |
+
"transformer.encoder.layers.1.self_attention.dense.weight": "pytorch_model-00001-of-00007.bin",
|
20 |
+
"transformer.encoder.layers.1.self_attention.query_key_value.bias": "pytorch_model-00001-of-00007.bin",
|
21 |
+
"transformer.encoder.layers.1.self_attention.query_key_value.weight": "pytorch_model-00001-of-00007.bin",
|
22 |
+
"transformer.encoder.layers.10.input_layernorm.weight": "pytorch_model-00003-of-00007.bin",
|
23 |
+
"transformer.encoder.layers.10.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00007.bin",
|
24 |
+
"transformer.encoder.layers.10.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00007.bin",
|
25 |
+
"transformer.encoder.layers.10.post_attention_layernorm.weight": "pytorch_model-00003-of-00007.bin",
|
26 |
+
"transformer.encoder.layers.10.self_attention.dense.weight": "pytorch_model-00003-of-00007.bin",
|
27 |
+
"transformer.encoder.layers.10.self_attention.query_key_value.bias": "pytorch_model-00003-of-00007.bin",
|
28 |
+
"transformer.encoder.layers.10.self_attention.query_key_value.weight": "pytorch_model-00003-of-00007.bin",
|
29 |
+
"transformer.encoder.layers.11.input_layernorm.weight": "pytorch_model-00003-of-00007.bin",
|
30 |
+
"transformer.encoder.layers.11.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00007.bin",
|
31 |
+
"transformer.encoder.layers.11.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00007.bin",
|
32 |
+
"transformer.encoder.layers.11.post_attention_layernorm.weight": "pytorch_model-00003-of-00007.bin",
|
33 |
+
"transformer.encoder.layers.11.self_attention.dense.weight": "pytorch_model-00003-of-00007.bin",
|
34 |
+
"transformer.encoder.layers.11.self_attention.query_key_value.bias": "pytorch_model-00003-of-00007.bin",
|
35 |
+
"transformer.encoder.layers.11.self_attention.query_key_value.weight": "pytorch_model-00003-of-00007.bin",
|
36 |
+
"transformer.encoder.layers.12.input_layernorm.weight": "pytorch_model-00003-of-00007.bin",
|
37 |
+
"transformer.encoder.layers.12.mlp.dense_4h_to_h.weight": "pytorch_model-00004-of-00007.bin",
|
38 |
+
"transformer.encoder.layers.12.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00007.bin",
|
39 |
+
"transformer.encoder.layers.12.post_attention_layernorm.weight": "pytorch_model-00003-of-00007.bin",
|
40 |
+
"transformer.encoder.layers.12.self_attention.dense.weight": "pytorch_model-00003-of-00007.bin",
|
41 |
+
"transformer.encoder.layers.12.self_attention.query_key_value.bias": "pytorch_model-00003-of-00007.bin",
|
42 |
+
"transformer.encoder.layers.12.self_attention.query_key_value.weight": "pytorch_model-00003-of-00007.bin",
|
43 |
+
"transformer.encoder.layers.13.input_layernorm.weight": "pytorch_model-00004-of-00007.bin",
|
44 |
+
"transformer.encoder.layers.13.mlp.dense_4h_to_h.weight": "pytorch_model-00004-of-00007.bin",
|
45 |
+
"transformer.encoder.layers.13.mlp.dense_h_to_4h.weight": "pytorch_model-00004-of-00007.bin",
|
46 |
+
"transformer.encoder.layers.13.post_attention_layernorm.weight": "pytorch_model-00004-of-00007.bin",
|
47 |
+
"transformer.encoder.layers.13.self_attention.dense.weight": "pytorch_model-00004-of-00007.bin",
|
48 |
+
"transformer.encoder.layers.13.self_attention.query_key_value.bias": "pytorch_model-00004-of-00007.bin",
|
49 |
+
"transformer.encoder.layers.13.self_attention.query_key_value.weight": "pytorch_model-00004-of-00007.bin",
|
50 |
+
"transformer.encoder.layers.14.input_layernorm.weight": "pytorch_model-00004-of-00007.bin",
|
51 |
+
"transformer.encoder.layers.14.mlp.dense_4h_to_h.weight": "pytorch_model-00004-of-00007.bin",
|
52 |
+
"transformer.encoder.layers.14.mlp.dense_h_to_4h.weight": "pytorch_model-00004-of-00007.bin",
|
53 |
+
"transformer.encoder.layers.14.post_attention_layernorm.weight": "pytorch_model-00004-of-00007.bin",
|
54 |
+
"transformer.encoder.layers.14.self_attention.dense.weight": "pytorch_model-00004-of-00007.bin",
|
55 |
+
"transformer.encoder.layers.14.self_attention.query_key_value.bias": "pytorch_model-00004-of-00007.bin",
|
56 |
+
"transformer.encoder.layers.14.self_attention.query_key_value.weight": "pytorch_model-00004-of-00007.bin",
|
57 |
+
"transformer.encoder.layers.15.input_layernorm.weight": "pytorch_model-00004-of-00007.bin",
|
58 |
+
"transformer.encoder.layers.15.mlp.dense_4h_to_h.weight": "pytorch_model-00004-of-00007.bin",
|
59 |
+
"transformer.encoder.layers.15.mlp.dense_h_to_4h.weight": "pytorch_model-00004-of-00007.bin",
|
60 |
+
"transformer.encoder.layers.15.post_attention_layernorm.weight": "pytorch_model-00004-of-00007.bin",
|
61 |
+
"transformer.encoder.layers.15.self_attention.dense.weight": "pytorch_model-00004-of-00007.bin",
|
62 |
+
"transformer.encoder.layers.15.self_attention.query_key_value.bias": "pytorch_model-00004-of-00007.bin",
|
63 |
+
"transformer.encoder.layers.15.self_attention.query_key_value.weight": "pytorch_model-00004-of-00007.bin",
|
64 |
+
"transformer.encoder.layers.16.input_layernorm.weight": "pytorch_model-00004-of-00007.bin",
|
65 |
+
"transformer.encoder.layers.16.mlp.dense_4h_to_h.weight": "pytorch_model-00004-of-00007.bin",
|
66 |
+
"transformer.encoder.layers.16.mlp.dense_h_to_4h.weight": "pytorch_model-00004-of-00007.bin",
|
67 |
+
"transformer.encoder.layers.16.post_attention_layernorm.weight": "pytorch_model-00004-of-00007.bin",
|
68 |
+
"transformer.encoder.layers.16.self_attention.dense.weight": "pytorch_model-00004-of-00007.bin",
|
69 |
+
"transformer.encoder.layers.16.self_attention.query_key_value.bias": "pytorch_model-00004-of-00007.bin",
|
70 |
+
"transformer.encoder.layers.16.self_attention.query_key_value.weight": "pytorch_model-00004-of-00007.bin",
|
71 |
+
"transformer.encoder.layers.17.input_layernorm.weight": "pytorch_model-00004-of-00007.bin",
|
72 |
+
"transformer.encoder.layers.17.mlp.dense_4h_to_h.weight": "pytorch_model-00005-of-00007.bin",
|
73 |
+
"transformer.encoder.layers.17.mlp.dense_h_to_4h.weight": "pytorch_model-00005-of-00007.bin",
|
74 |
+
"transformer.encoder.layers.17.post_attention_layernorm.weight": "pytorch_model-00004-of-00007.bin",
|
75 |
+
"transformer.encoder.layers.17.self_attention.dense.weight": "pytorch_model-00004-of-00007.bin",
|
76 |
+
"transformer.encoder.layers.17.self_attention.query_key_value.bias": "pytorch_model-00004-of-00007.bin",
|
77 |
+
"transformer.encoder.layers.17.self_attention.query_key_value.weight": "pytorch_model-00004-of-00007.bin",
|
78 |
+
"transformer.encoder.layers.18.input_layernorm.weight": "pytorch_model-00005-of-00007.bin",
|
79 |
+
"transformer.encoder.layers.18.mlp.dense_4h_to_h.weight": "pytorch_model-00005-of-00007.bin",
|
80 |
+
"transformer.encoder.layers.18.mlp.dense_h_to_4h.weight": "pytorch_model-00005-of-00007.bin",
|
81 |
+
"transformer.encoder.layers.18.post_attention_layernorm.weight": "pytorch_model-00005-of-00007.bin",
|
82 |
+
"transformer.encoder.layers.18.self_attention.dense.weight": "pytorch_model-00005-of-00007.bin",
|
83 |
+
"transformer.encoder.layers.18.self_attention.query_key_value.bias": "pytorch_model-00005-of-00007.bin",
|
84 |
+
"transformer.encoder.layers.18.self_attention.query_key_value.weight": "pytorch_model-00005-of-00007.bin",
|
85 |
+
"transformer.encoder.layers.19.input_layernorm.weight": "pytorch_model-00005-of-00007.bin",
|
86 |
+
"transformer.encoder.layers.19.mlp.dense_4h_to_h.weight": "pytorch_model-00005-of-00007.bin",
|
87 |
+
"transformer.encoder.layers.19.mlp.dense_h_to_4h.weight": "pytorch_model-00005-of-00007.bin",
|
88 |
+
"transformer.encoder.layers.19.post_attention_layernorm.weight": "pytorch_model-00005-of-00007.bin",
|
89 |
+
"transformer.encoder.layers.19.self_attention.dense.weight": "pytorch_model-00005-of-00007.bin",
|
90 |
+
"transformer.encoder.layers.19.self_attention.query_key_value.bias": "pytorch_model-00005-of-00007.bin",
|
91 |
+
"transformer.encoder.layers.19.self_attention.query_key_value.weight": "pytorch_model-00005-of-00007.bin",
|
92 |
+
"transformer.encoder.layers.2.input_layernorm.weight": "pytorch_model-00001-of-00007.bin",
|
93 |
+
"transformer.encoder.layers.2.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00007.bin",
|
94 |
+
"transformer.encoder.layers.2.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00007.bin",
|
95 |
+
"transformer.encoder.layers.2.post_attention_layernorm.weight": "pytorch_model-00001-of-00007.bin",
|
96 |
+
"transformer.encoder.layers.2.self_attention.dense.weight": "pytorch_model-00001-of-00007.bin",
|
97 |
+
"transformer.encoder.layers.2.self_attention.query_key_value.bias": "pytorch_model-00001-of-00007.bin",
|
98 |
+
"transformer.encoder.layers.2.self_attention.query_key_value.weight": "pytorch_model-00001-of-00007.bin",
|
99 |
+
"transformer.encoder.layers.20.input_layernorm.weight": "pytorch_model-00005-of-00007.bin",
|
100 |
+
"transformer.encoder.layers.20.mlp.dense_4h_to_h.weight": "pytorch_model-00005-of-00007.bin",
|
101 |
+
"transformer.encoder.layers.20.mlp.dense_h_to_4h.weight": "pytorch_model-00005-of-00007.bin",
|
102 |
+
"transformer.encoder.layers.20.post_attention_layernorm.weight": "pytorch_model-00005-of-00007.bin",
|
103 |
+
"transformer.encoder.layers.20.self_attention.dense.weight": "pytorch_model-00005-of-00007.bin",
|
104 |
+
"transformer.encoder.layers.20.self_attention.query_key_value.bias": "pytorch_model-00005-of-00007.bin",
|
105 |
+
"transformer.encoder.layers.20.self_attention.query_key_value.weight": "pytorch_model-00005-of-00007.bin",
|
106 |
+
"transformer.encoder.layers.21.input_layernorm.weight": "pytorch_model-00005-of-00007.bin",
|
107 |
+
"transformer.encoder.layers.21.mlp.dense_4h_to_h.weight": "pytorch_model-00005-of-00007.bin",
|
108 |
+
"transformer.encoder.layers.21.mlp.dense_h_to_4h.weight": "pytorch_model-00005-of-00007.bin",
|
109 |
+
"transformer.encoder.layers.21.post_attention_layernorm.weight": "pytorch_model-00005-of-00007.bin",
|
110 |
+
"transformer.encoder.layers.21.self_attention.dense.weight": "pytorch_model-00005-of-00007.bin",
|
111 |
+
"transformer.encoder.layers.21.self_attention.query_key_value.bias": "pytorch_model-00005-of-00007.bin",
|
112 |
+
"transformer.encoder.layers.21.self_attention.query_key_value.weight": "pytorch_model-00005-of-00007.bin",
|
113 |
+
"transformer.encoder.layers.22.input_layernorm.weight": "pytorch_model-00005-of-00007.bin",
|
114 |
+
"transformer.encoder.layers.22.mlp.dense_4h_to_h.weight": "pytorch_model-00006-of-00007.bin",
|
115 |
+
"transformer.encoder.layers.22.mlp.dense_h_to_4h.weight": "pytorch_model-00006-of-00007.bin",
|
116 |
+
"transformer.encoder.layers.22.post_attention_layernorm.weight": "pytorch_model-00006-of-00007.bin",
|
117 |
+
"transformer.encoder.layers.22.self_attention.dense.weight": "pytorch_model-00006-of-00007.bin",
|
118 |
+
"transformer.encoder.layers.22.self_attention.query_key_value.bias": "pytorch_model-00006-of-00007.bin",
|
119 |
+
"transformer.encoder.layers.22.self_attention.query_key_value.weight": "pytorch_model-00006-of-00007.bin",
|
120 |
+
"transformer.encoder.layers.23.input_layernorm.weight": "pytorch_model-00006-of-00007.bin",
|
121 |
+
"transformer.encoder.layers.23.mlp.dense_4h_to_h.weight": "pytorch_model-00006-of-00007.bin",
|
122 |
+
"transformer.encoder.layers.23.mlp.dense_h_to_4h.weight": "pytorch_model-00006-of-00007.bin",
|
123 |
+
"transformer.encoder.layers.23.post_attention_layernorm.weight": "pytorch_model-00006-of-00007.bin",
|
124 |
+
"transformer.encoder.layers.23.self_attention.dense.weight": "pytorch_model-00006-of-00007.bin",
|
125 |
+
"transformer.encoder.layers.23.self_attention.query_key_value.bias": "pytorch_model-00006-of-00007.bin",
|
126 |
+
"transformer.encoder.layers.23.self_attention.query_key_value.weight": "pytorch_model-00006-of-00007.bin",
|
127 |
+
"transformer.encoder.layers.24.input_layernorm.weight": "pytorch_model-00006-of-00007.bin",
|
128 |
+
"transformer.encoder.layers.24.mlp.dense_4h_to_h.weight": "pytorch_model-00006-of-00007.bin",
|
129 |
+
"transformer.encoder.layers.24.mlp.dense_h_to_4h.weight": "pytorch_model-00006-of-00007.bin",
|
130 |
+
"transformer.encoder.layers.24.post_attention_layernorm.weight": "pytorch_model-00006-of-00007.bin",
|
131 |
+
"transformer.encoder.layers.24.self_attention.dense.weight": "pytorch_model-00006-of-00007.bin",
|
132 |
+
"transformer.encoder.layers.24.self_attention.query_key_value.bias": "pytorch_model-00006-of-00007.bin",
|
133 |
+
"transformer.encoder.layers.24.self_attention.query_key_value.weight": "pytorch_model-00006-of-00007.bin",
|
134 |
+
"transformer.encoder.layers.25.input_layernorm.weight": "pytorch_model-00006-of-00007.bin",
|
135 |
+
"transformer.encoder.layers.25.mlp.dense_4h_to_h.weight": "pytorch_model-00006-of-00007.bin",
|
136 |
+
"transformer.encoder.layers.25.mlp.dense_h_to_4h.weight": "pytorch_model-00006-of-00007.bin",
|
137 |
+
"transformer.encoder.layers.25.post_attention_layernorm.weight": "pytorch_model-00006-of-00007.bin",
|
138 |
+
"transformer.encoder.layers.25.self_attention.dense.weight": "pytorch_model-00006-of-00007.bin",
|
139 |
+
"transformer.encoder.layers.25.self_attention.query_key_value.bias": "pytorch_model-00006-of-00007.bin",
|
140 |
+
"transformer.encoder.layers.25.self_attention.query_key_value.weight": "pytorch_model-00006-of-00007.bin",
|
141 |
+
"transformer.encoder.layers.26.input_layernorm.weight": "pytorch_model-00006-of-00007.bin",
|
142 |
+
"transformer.encoder.layers.26.mlp.dense_4h_to_h.weight": "pytorch_model-00007-of-00007.bin",
|
143 |
+
"transformer.encoder.layers.26.mlp.dense_h_to_4h.weight": "pytorch_model-00006-of-00007.bin",
|
144 |
+
"transformer.encoder.layers.26.post_attention_layernorm.weight": "pytorch_model-00006-of-00007.bin",
|
145 |
+
"transformer.encoder.layers.26.self_attention.dense.weight": "pytorch_model-00006-of-00007.bin",
|
146 |
+
"transformer.encoder.layers.26.self_attention.query_key_value.bias": "pytorch_model-00006-of-00007.bin",
|
147 |
+
"transformer.encoder.layers.26.self_attention.query_key_value.weight": "pytorch_model-00006-of-00007.bin",
|
148 |
+
"transformer.encoder.layers.27.input_layernorm.weight": "pytorch_model-00007-of-00007.bin",
|
149 |
+
"transformer.encoder.layers.27.mlp.dense_4h_to_h.weight": "pytorch_model-00007-of-00007.bin",
|
150 |
+
"transformer.encoder.layers.27.mlp.dense_h_to_4h.weight": "pytorch_model-00007-of-00007.bin",
|
151 |
+
"transformer.encoder.layers.27.post_attention_layernorm.weight": "pytorch_model-00007-of-00007.bin",
|
152 |
+
"transformer.encoder.layers.27.self_attention.dense.weight": "pytorch_model-00007-of-00007.bin",
|
153 |
+
"transformer.encoder.layers.27.self_attention.query_key_value.bias": "pytorch_model-00007-of-00007.bin",
|
154 |
+
"transformer.encoder.layers.27.self_attention.query_key_value.weight": "pytorch_model-00007-of-00007.bin",
|
155 |
+
"transformer.encoder.layers.3.input_layernorm.weight": "pytorch_model-00001-of-00007.bin",
|
156 |
+
"transformer.encoder.layers.3.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00007.bin",
|
157 |
+
"transformer.encoder.layers.3.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00007.bin",
|
158 |
+
"transformer.encoder.layers.3.post_attention_layernorm.weight": "pytorch_model-00001-of-00007.bin",
|
159 |
+
"transformer.encoder.layers.3.self_attention.dense.weight": "pytorch_model-00001-of-00007.bin",
|
160 |
+
"transformer.encoder.layers.3.self_attention.query_key_value.bias": "pytorch_model-00001-of-00007.bin",
|
161 |
+
"transformer.encoder.layers.3.self_attention.query_key_value.weight": "pytorch_model-00001-of-00007.bin",
|
162 |
+
"transformer.encoder.layers.4.input_layernorm.weight": "pytorch_model-00002-of-00007.bin",
|
163 |
+
"transformer.encoder.layers.4.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00007.bin",
|
164 |
+
"transformer.encoder.layers.4.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00007.bin",
|
165 |
+
"transformer.encoder.layers.4.post_attention_layernorm.weight": "pytorch_model-00002-of-00007.bin",
|
166 |
+
"transformer.encoder.layers.4.self_attention.dense.weight": "pytorch_model-00002-of-00007.bin",
|
167 |
+
"transformer.encoder.layers.4.self_attention.query_key_value.bias": "pytorch_model-00002-of-00007.bin",
|
168 |
+
"transformer.encoder.layers.4.self_attention.query_key_value.weight": "pytorch_model-00002-of-00007.bin",
|
169 |
+
"transformer.encoder.layers.5.input_layernorm.weight": "pytorch_model-00002-of-00007.bin",
|
170 |
+
"transformer.encoder.layers.5.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00007.bin",
|
171 |
+
"transformer.encoder.layers.5.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00007.bin",
|
172 |
+
"transformer.encoder.layers.5.post_attention_layernorm.weight": "pytorch_model-00002-of-00007.bin",
|
173 |
+
"transformer.encoder.layers.5.self_attention.dense.weight": "pytorch_model-00002-of-00007.bin",
|
174 |
+
"transformer.encoder.layers.5.self_attention.query_key_value.bias": "pytorch_model-00002-of-00007.bin",
|
175 |
+
"transformer.encoder.layers.5.self_attention.query_key_value.weight": "pytorch_model-00002-of-00007.bin",
|
176 |
+
"transformer.encoder.layers.6.input_layernorm.weight": "pytorch_model-00002-of-00007.bin",
|
177 |
+
"transformer.encoder.layers.6.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00007.bin",
|
178 |
+
"transformer.encoder.layers.6.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00007.bin",
|
179 |
+
"transformer.encoder.layers.6.post_attention_layernorm.weight": "pytorch_model-00002-of-00007.bin",
|
180 |
+
"transformer.encoder.layers.6.self_attention.dense.weight": "pytorch_model-00002-of-00007.bin",
|
181 |
+
"transformer.encoder.layers.6.self_attention.query_key_value.bias": "pytorch_model-00002-of-00007.bin",
|
182 |
+
"transformer.encoder.layers.6.self_attention.query_key_value.weight": "pytorch_model-00002-of-00007.bin",
|
183 |
+
"transformer.encoder.layers.7.input_layernorm.weight": "pytorch_model-00002-of-00007.bin",
|
184 |
+
"transformer.encoder.layers.7.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00007.bin",
|
185 |
+
"transformer.encoder.layers.7.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00007.bin",
|
186 |
+
"transformer.encoder.layers.7.post_attention_layernorm.weight": "pytorch_model-00002-of-00007.bin",
|
187 |
+
"transformer.encoder.layers.7.self_attention.dense.weight": "pytorch_model-00002-of-00007.bin",
|
188 |
+
"transformer.encoder.layers.7.self_attention.query_key_value.bias": "pytorch_model-00002-of-00007.bin",
|
189 |
+
"transformer.encoder.layers.7.self_attention.query_key_value.weight": "pytorch_model-00002-of-00007.bin",
|
190 |
+
"transformer.encoder.layers.8.input_layernorm.weight": "pytorch_model-00002-of-00007.bin",
|
191 |
+
"transformer.encoder.layers.8.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00007.bin",
|
192 |
+
"transformer.encoder.layers.8.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00007.bin",
|
193 |
+
"transformer.encoder.layers.8.post_attention_layernorm.weight": "pytorch_model-00003-of-00007.bin",
|
194 |
+
"transformer.encoder.layers.8.self_attention.dense.weight": "pytorch_model-00003-of-00007.bin",
|
195 |
+
"transformer.encoder.layers.8.self_attention.query_key_value.bias": "pytorch_model-00003-of-00007.bin",
|
196 |
+
"transformer.encoder.layers.8.self_attention.query_key_value.weight": "pytorch_model-00003-of-00007.bin",
|
197 |
+
"transformer.encoder.layers.9.input_layernorm.weight": "pytorch_model-00003-of-00007.bin",
|
198 |
+
"transformer.encoder.layers.9.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00007.bin",
|
199 |
+
"transformer.encoder.layers.9.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00007.bin",
|
200 |
+
"transformer.encoder.layers.9.post_attention_layernorm.weight": "pytorch_model-00003-of-00007.bin",
|
201 |
+
"transformer.encoder.layers.9.self_attention.dense.weight": "pytorch_model-00003-of-00007.bin",
|
202 |
+
"transformer.encoder.layers.9.self_attention.query_key_value.bias": "pytorch_model-00003-of-00007.bin",
|
203 |
+
"transformer.encoder.layers.9.self_attention.query_key_value.weight": "pytorch_model-00003-of-00007.bin",
|
204 |
+
"transformer.output_layer.weight": "pytorch_model-00007-of-00007.bin",
|
205 |
+
"transformer.rotary_pos_emb.inv_freq": "pytorch_model-00001-of-00007.bin"
|
206 |
+
}
|
207 |
+
}
|
chatglm3-6b/quantization.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.nn import Linear
|
2 |
+
from torch.nn.parameter import Parameter
|
3 |
+
|
4 |
+
import bz2
|
5 |
+
import torch
|
6 |
+
import base64
|
7 |
+
import ctypes
|
8 |
+
from transformers.utils import logging
|
9 |
+
|
10 |
+
from typing import List
|
11 |
+
from functools import partial
|
12 |
+
|
13 |
+
logger = logging.get_logger(__name__)
|
14 |
+
|
15 |
+
try:
|
16 |
+
from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up
|
17 |
+
|
18 |
+
class Kernel:
|
19 |
+
def __init__(self, code: bytes, function_names: List[str]):
|
20 |
+
self.code = code
|
21 |
+
self._function_names = function_names
|
22 |
+
self._cmodule = LazyKernelCModule(self.code)
|
23 |
+
|
24 |
+
for name in self._function_names:
|
25 |
+
setattr(self, name, KernelFunction(self._cmodule, name))
|
26 |
+
|
27 |
+
quantization_code = "$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+te+9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BPXvZfz3CLyHpix+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ"
|
28 |
+
|
29 |
+
kernels = Kernel(
|
30 |
+
bz2.decompress(base64.b64decode(quantization_code)),
|
31 |
+
[
|
32 |
+
"int4WeightCompression",
|
33 |
+
"int4WeightExtractionFloat",
|
34 |
+
"int4WeightExtractionHalf",
|
35 |
+
"int8WeightExtractionFloat",
|
36 |
+
"int8WeightExtractionHalf",
|
37 |
+
],
|
38 |
+
)
|
39 |
+
except Exception as exception:
|
40 |
+
kernels = None
|
41 |
+
logger.warning("Failed to load cpm_kernels:" + str(exception))
|
42 |
+
|
43 |
+
|
44 |
+
class W8A16Linear(torch.autograd.Function):
|
45 |
+
@staticmethod
|
46 |
+
def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
|
47 |
+
ctx.inp_shape = inp.size()
|
48 |
+
ctx.weight_bit_width = weight_bit_width
|
49 |
+
out_features = quant_w.size(0)
|
50 |
+
inp = inp.contiguous().view(-1, inp.size(-1))
|
51 |
+
weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)
|
52 |
+
ctx.weight_shape = weight.size()
|
53 |
+
output = inp.mm(weight.t())
|
54 |
+
ctx.save_for_backward(inp, quant_w, scale_w)
|
55 |
+
return output.view(*(ctx.inp_shape[:-1] + (out_features,)))
|
56 |
+
|
57 |
+
@staticmethod
|
58 |
+
def backward(ctx, grad_output: torch.Tensor):
|
59 |
+
inp, quant_w, scale_w = ctx.saved_tensors
|
60 |
+
weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width)
|
61 |
+
grad_output = grad_output.contiguous().view(-1, weight.size(0))
|
62 |
+
grad_input = grad_output.mm(weight)
|
63 |
+
grad_weight = grad_output.t().mm(inp)
|
64 |
+
return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None
|
65 |
+
|
66 |
+
|
67 |
+
def compress_int4_weight(weight: torch.Tensor): # (n, m)
|
68 |
+
with torch.cuda.device(weight.device):
|
69 |
+
n, m = weight.size(0), weight.size(1)
|
70 |
+
assert m % 2 == 0
|
71 |
+
m = m // 2
|
72 |
+
out = torch.empty(n, m, dtype=torch.int8, device="cuda")
|
73 |
+
stream = torch.cuda.current_stream()
|
74 |
+
|
75 |
+
gridDim = (n, 1, 1)
|
76 |
+
blockDim = (min(round_up(m, 32), 1024), 1, 1)
|
77 |
+
|
78 |
+
kernels.int4WeightCompression(
|
79 |
+
gridDim,
|
80 |
+
blockDim,
|
81 |
+
0,
|
82 |
+
stream,
|
83 |
+
[ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)],
|
84 |
+
)
|
85 |
+
return out
|
86 |
+
|
87 |
+
|
88 |
+
def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int):
|
89 |
+
assert scale_list.dtype in [torch.half, torch.bfloat16]
|
90 |
+
assert weight.dtype in [torch.int8]
|
91 |
+
if source_bit_width == 8:
|
92 |
+
return weight.to(scale_list.dtype) * scale_list[:, None]
|
93 |
+
elif source_bit_width == 4:
|
94 |
+
func = (
|
95 |
+
kernels.int4WeightExtractionHalf if scale_list.dtype == torch.half else kernels.int4WeightExtractionBFloat16
|
96 |
+
)
|
97 |
+
else:
|
98 |
+
assert False, "Unsupported bit-width"
|
99 |
+
|
100 |
+
with torch.cuda.device(weight.device):
|
101 |
+
n, m = weight.size(0), weight.size(1)
|
102 |
+
out = torch.empty(n, m * (8 // source_bit_width), dtype=scale_list.dtype, device="cuda")
|
103 |
+
stream = torch.cuda.current_stream()
|
104 |
+
|
105 |
+
gridDim = (n, 1, 1)
|
106 |
+
blockDim = (min(round_up(m, 32), 1024), 1, 1)
|
107 |
+
|
108 |
+
func(
|
109 |
+
gridDim,
|
110 |
+
blockDim,
|
111 |
+
0,
|
112 |
+
stream,
|
113 |
+
[
|
114 |
+
ctypes.c_void_p(weight.data_ptr()),
|
115 |
+
ctypes.c_void_p(scale_list.data_ptr()),
|
116 |
+
ctypes.c_void_p(out.data_ptr()),
|
117 |
+
ctypes.c_int32(n),
|
118 |
+
ctypes.c_int32(m),
|
119 |
+
],
|
120 |
+
)
|
121 |
+
return out
|
122 |
+
|
123 |
+
|
124 |
+
class QuantizedLinear(torch.nn.Module):
|
125 |
+
def __init__(self, weight_bit_width: int, weight, bias=None, device="cpu", dtype=None, empty_init=False, *args,
|
126 |
+
**kwargs):
|
127 |
+
super().__init__()
|
128 |
+
self.weight_bit_width = weight_bit_width
|
129 |
+
|
130 |
+
shape = weight.shape
|
131 |
+
|
132 |
+
if weight is None or empty_init:
|
133 |
+
self.weight = torch.empty(shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=device)
|
134 |
+
self.weight_scale = torch.empty(shape[0], dtype=dtype, device=device)
|
135 |
+
else:
|
136 |
+
self.weight_scale = weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)
|
137 |
+
self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
|
138 |
+
if weight_bit_width == 4:
|
139 |
+
self.weight = compress_int4_weight(self.weight)
|
140 |
+
|
141 |
+
self.weight = Parameter(self.weight.to(device), requires_grad=False)
|
142 |
+
self.weight_scale = Parameter(self.weight_scale.to(device), requires_grad=False)
|
143 |
+
self.bias = Parameter(bias.to(device), requires_grad=False) if bias is not None else None
|
144 |
+
|
145 |
+
def forward(self, input):
|
146 |
+
output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width)
|
147 |
+
if self.bias is not None:
|
148 |
+
output = output + self.bias
|
149 |
+
return output
|
150 |
+
|
151 |
+
|
152 |
+
def quantize(model, weight_bit_width, empty_init=False, device=None):
|
153 |
+
"""Replace fp16 linear with quantized linear"""
|
154 |
+
for layer in model.layers:
|
155 |
+
layer.self_attention.query_key_value = QuantizedLinear(
|
156 |
+
weight_bit_width=weight_bit_width,
|
157 |
+
weight=layer.self_attention.query_key_value.weight.to(torch.cuda.current_device()),
|
158 |
+
bias=layer.self_attention.query_key_value.bias,
|
159 |
+
dtype=layer.self_attention.query_key_value.weight.dtype,
|
160 |
+
device=layer.self_attention.query_key_value.weight.device if device is None else device,
|
161 |
+
empty_init=empty_init
|
162 |
+
)
|
163 |
+
layer.self_attention.dense = QuantizedLinear(
|
164 |
+
weight_bit_width=weight_bit_width,
|
165 |
+
weight=layer.self_attention.dense.weight.to(torch.cuda.current_device()),
|
166 |
+
bias=layer.self_attention.dense.bias,
|
167 |
+
dtype=layer.self_attention.dense.weight.dtype,
|
168 |
+
device=layer.self_attention.dense.weight.device if device is None else device,
|
169 |
+
empty_init=empty_init
|
170 |
+
)
|
171 |
+
layer.mlp.dense_h_to_4h = QuantizedLinear(
|
172 |
+
weight_bit_width=weight_bit_width,
|
173 |
+
weight=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()),
|
174 |
+
bias=layer.mlp.dense_h_to_4h.bias,
|
175 |
+
dtype=layer.mlp.dense_h_to_4h.weight.dtype,
|
176 |
+
device=layer.mlp.dense_h_to_4h.weight.device if device is None else device,
|
177 |
+
empty_init=empty_init
|
178 |
+
)
|
179 |
+
layer.mlp.dense_4h_to_h = QuantizedLinear(
|
180 |
+
weight_bit_width=weight_bit_width,
|
181 |
+
weight=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()),
|
182 |
+
bias=layer.mlp.dense_4h_to_h.bias,
|
183 |
+
dtype=layer.mlp.dense_4h_to_h.weight.dtype,
|
184 |
+
device=layer.mlp.dense_4h_to_h.weight.device if device is None else device,
|
185 |
+
empty_init=empty_init
|
186 |
+
)
|
187 |
+
|
188 |
+
return model
|
chatglm3-6b/tokenization_chatglm.py
ADDED
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
from typing import List, Optional, Union, Dict
|
5 |
+
from sentencepiece import SentencePieceProcessor
|
6 |
+
from transformers import PreTrainedTokenizer
|
7 |
+
from transformers.utils import logging, PaddingStrategy
|
8 |
+
from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
|
9 |
+
|
10 |
+
|
11 |
+
class SPTokenizer:
|
12 |
+
def __init__(self, model_path: str):
|
13 |
+
# reload tokenizer
|
14 |
+
assert os.path.isfile(model_path), model_path
|
15 |
+
self.sp_model = SentencePieceProcessor(model_file=model_path)
|
16 |
+
|
17 |
+
# BOS / EOS token IDs
|
18 |
+
self.n_words: int = self.sp_model.vocab_size()
|
19 |
+
self.bos_id: int = self.sp_model.bos_id()
|
20 |
+
self.eos_id: int = self.sp_model.eos_id()
|
21 |
+
self.pad_id: int = self.sp_model.unk_id()
|
22 |
+
assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
|
23 |
+
|
24 |
+
role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
|
25 |
+
special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
|
26 |
+
self.special_tokens = {}
|
27 |
+
self.index_special_tokens = {}
|
28 |
+
for token in special_tokens:
|
29 |
+
self.special_tokens[token] = self.n_words
|
30 |
+
self.index_special_tokens[self.n_words] = token
|
31 |
+
self.n_words += 1
|
32 |
+
self.role_special_token_expression = "|".join([re.escape(token) for token in role_special_tokens])
|
33 |
+
|
34 |
+
def tokenize(self, s: str, encode_special_tokens=False):
|
35 |
+
if encode_special_tokens:
|
36 |
+
last_index = 0
|
37 |
+
t = []
|
38 |
+
for match in re.finditer(self.role_special_token_expression, s):
|
39 |
+
if last_index < match.start():
|
40 |
+
t.extend(self.sp_model.EncodeAsPieces(s[last_index:match.start()]))
|
41 |
+
t.append(s[match.start():match.end()])
|
42 |
+
last_index = match.end()
|
43 |
+
if last_index < len(s):
|
44 |
+
t.extend(self.sp_model.EncodeAsPieces(s[last_index:]))
|
45 |
+
return t
|
46 |
+
else:
|
47 |
+
return self.sp_model.EncodeAsPieces(s)
|
48 |
+
|
49 |
+
def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
|
50 |
+
assert type(s) is str
|
51 |
+
t = self.sp_model.encode(s)
|
52 |
+
if bos:
|
53 |
+
t = [self.bos_id] + t
|
54 |
+
if eos:
|
55 |
+
t = t + [self.eos_id]
|
56 |
+
return t
|
57 |
+
|
58 |
+
def decode(self, t: List[int]) -> str:
|
59 |
+
text, buffer = "", []
|
60 |
+
for token in t:
|
61 |
+
if token in self.index_special_tokens:
|
62 |
+
if buffer:
|
63 |
+
text += self.sp_model.decode(buffer)
|
64 |
+
buffer = []
|
65 |
+
text += self.index_special_tokens[token]
|
66 |
+
else:
|
67 |
+
buffer.append(token)
|
68 |
+
if buffer:
|
69 |
+
text += self.sp_model.decode(buffer)
|
70 |
+
return text
|
71 |
+
|
72 |
+
def decode_tokens(self, tokens: List[str]) -> str:
|
73 |
+
text = self.sp_model.DecodePieces(tokens)
|
74 |
+
return text
|
75 |
+
|
76 |
+
def convert_token_to_id(self, token):
|
77 |
+
""" Converts a token (str) in an id using the vocab. """
|
78 |
+
if token in self.special_tokens:
|
79 |
+
return self.special_tokens[token]
|
80 |
+
return self.sp_model.PieceToId(token)
|
81 |
+
|
82 |
+
def convert_id_to_token(self, index):
|
83 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
84 |
+
if index in self.index_special_tokens:
|
85 |
+
return self.index_special_tokens[index]
|
86 |
+
if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
|
87 |
+
return ""
|
88 |
+
return self.sp_model.IdToPiece(index)
|
89 |
+
|
90 |
+
|
91 |
+
class ChatGLMTokenizer(PreTrainedTokenizer):
|
92 |
+
vocab_files_names = {"vocab_file": "tokenizer.model"}
|
93 |
+
|
94 |
+
model_input_names = ["input_ids", "attention_mask", "position_ids"]
|
95 |
+
|
96 |
+
def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, encode_special_tokens=False,
|
97 |
+
**kwargs):
|
98 |
+
self.name = "GLMTokenizer"
|
99 |
+
|
100 |
+
self.vocab_file = vocab_file
|
101 |
+
self.tokenizer = SPTokenizer(vocab_file)
|
102 |
+
self.special_tokens = {
|
103 |
+
"<bos>": self.tokenizer.bos_id,
|
104 |
+
"<eos>": self.tokenizer.eos_id,
|
105 |
+
"<pad>": self.tokenizer.pad_id
|
106 |
+
}
|
107 |
+
self.encode_special_tokens = encode_special_tokens
|
108 |
+
super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
109 |
+
encode_special_tokens=encode_special_tokens,
|
110 |
+
**kwargs)
|
111 |
+
|
112 |
+
def get_command(self, token):
|
113 |
+
if token in self.special_tokens:
|
114 |
+
return self.special_tokens[token]
|
115 |
+
assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
|
116 |
+
return self.tokenizer.special_tokens[token]
|
117 |
+
|
118 |
+
@property
|
119 |
+
def unk_token(self) -> str:
|
120 |
+
return "<unk>"
|
121 |
+
|
122 |
+
@property
|
123 |
+
def pad_token(self) -> str:
|
124 |
+
return "<unk>"
|
125 |
+
|
126 |
+
@property
|
127 |
+
def pad_token_id(self):
|
128 |
+
return self.get_command("<pad>")
|
129 |
+
|
130 |
+
@property
|
131 |
+
def eos_token(self) -> str:
|
132 |
+
return "</s>"
|
133 |
+
|
134 |
+
@property
|
135 |
+
def eos_token_id(self):
|
136 |
+
return self.get_command("<eos>")
|
137 |
+
|
138 |
+
@property
|
139 |
+
def vocab_size(self):
|
140 |
+
return self.tokenizer.n_words
|
141 |
+
|
142 |
+
def get_vocab(self):
|
143 |
+
""" Returns vocab as a dict """
|
144 |
+
vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
|
145 |
+
vocab.update(self.added_tokens_encoder)
|
146 |
+
return vocab
|
147 |
+
|
148 |
+
def _tokenize(self, text, **kwargs):
|
149 |
+
return self.tokenizer.tokenize(text, encode_special_tokens=self.encode_special_tokens)
|
150 |
+
|
151 |
+
def _convert_token_to_id(self, token):
|
152 |
+
""" Converts a token (str) in an id using the vocab. """
|
153 |
+
return self.tokenizer.convert_token_to_id(token)
|
154 |
+
|
155 |
+
def _convert_id_to_token(self, index):
|
156 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
157 |
+
return self.tokenizer.convert_id_to_token(index)
|
158 |
+
|
159 |
+
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
160 |
+
return self.tokenizer.decode_tokens(tokens)
|
161 |
+
|
162 |
+
def save_vocabulary(self, save_directory, filename_prefix=None):
|
163 |
+
"""
|
164 |
+
Save the vocabulary and special tokens file to a directory.
|
165 |
+
|
166 |
+
Args:
|
167 |
+
save_directory (`str`):
|
168 |
+
The directory in which to save the vocabulary.
|
169 |
+
filename_prefix (`str`, *optional*):
|
170 |
+
An optional prefix to add to the named of the saved files.
|
171 |
+
|
172 |
+
Returns:
|
173 |
+
`Tuple(str)`: Paths to the files saved.
|
174 |
+
"""
|
175 |
+
if os.path.isdir(save_directory):
|
176 |
+
vocab_file = os.path.join(
|
177 |
+
save_directory, self.vocab_files_names["vocab_file"]
|
178 |
+
)
|
179 |
+
else:
|
180 |
+
vocab_file = save_directory
|
181 |
+
|
182 |
+
with open(self.vocab_file, 'rb') as fin:
|
183 |
+
proto_str = fin.read()
|
184 |
+
|
185 |
+
with open(vocab_file, "wb") as writer:
|
186 |
+
writer.write(proto_str)
|
187 |
+
|
188 |
+
return (vocab_file,)
|
189 |
+
|
190 |
+
def get_prefix_tokens(self):
|
191 |
+
prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
|
192 |
+
return prefix_tokens
|
193 |
+
|
194 |
+
def build_single_message(self, role, metadata, message):
|
195 |
+
assert role in ["system", "user", "assistant", "observation"], role
|
196 |
+
role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
|
197 |
+
message_tokens = self.tokenizer.encode(message)
|
198 |
+
tokens = role_tokens + message_tokens
|
199 |
+
return tokens
|
200 |
+
|
201 |
+
def build_chat_input(self, query, history=None, role="user"):
|
202 |
+
if history is None:
|
203 |
+
history = []
|
204 |
+
input_ids = []
|
205 |
+
for item in history:
|
206 |
+
content = item["content"]
|
207 |
+
if item["role"] == "system" and "tools" in item:
|
208 |
+
content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
|
209 |
+
input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
|
210 |
+
input_ids.extend(self.build_single_message(role, "", query))
|
211 |
+
input_ids.extend([self.get_command("<|assistant|>")])
|
212 |
+
return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)
|
213 |
+
|
214 |
+
def build_inputs_with_special_tokens(
|
215 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
216 |
+
) -> List[int]:
|
217 |
+
"""
|
218 |
+
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
219 |
+
adding special tokens. A BERT sequence has the following format:
|
220 |
+
|
221 |
+
- single sequence: `[CLS] X [SEP]`
|
222 |
+
- pair of sequences: `[CLS] A [SEP] B [SEP]`
|
223 |
+
|
224 |
+
Args:
|
225 |
+
token_ids_0 (`List[int]`):
|
226 |
+
List of IDs to which the special tokens will be added.
|
227 |
+
token_ids_1 (`List[int]`, *optional*):
|
228 |
+
Optional second list of IDs for sequence pairs.
|
229 |
+
|
230 |
+
Returns:
|
231 |
+
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
232 |
+
"""
|
233 |
+
prefix_tokens = self.get_prefix_tokens()
|
234 |
+
token_ids_0 = prefix_tokens + token_ids_0
|
235 |
+
if token_ids_1 is not None:
|
236 |
+
token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
|
237 |
+
return token_ids_0
|
238 |
+
|
239 |
+
def _pad(
|
240 |
+
self,
|
241 |
+
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
|
242 |
+
max_length: Optional[int] = None,
|
243 |
+
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
244 |
+
pad_to_multiple_of: Optional[int] = None,
|
245 |
+
return_attention_mask: Optional[bool] = None,
|
246 |
+
) -> dict:
|
247 |
+
"""
|
248 |
+
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
|
249 |
+
|
250 |
+
Args:
|
251 |
+
encoded_inputs:
|
252 |
+
Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
|
253 |
+
max_length: maximum length of the returned list and optionally padding length (see below).
|
254 |
+
Will truncate by taking into account the special tokens.
|
255 |
+
padding_strategy: PaddingStrategy to use for padding.
|
256 |
+
|
257 |
+
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
|
258 |
+
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
|
259 |
+
- PaddingStrategy.DO_NOT_PAD: Do not pad
|
260 |
+
The tokenizer padding sides are defined in self.padding_side:
|
261 |
+
|
262 |
+
- 'left': pads on the left of the sequences
|
263 |
+
- 'right': pads on the right of the sequences
|
264 |
+
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
|
265 |
+
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
|
266 |
+
`>= 7.5` (Volta).
|
267 |
+
return_attention_mask:
|
268 |
+
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
|
269 |
+
"""
|
270 |
+
# Load from model defaults
|
271 |
+
assert self.padding_side == "left"
|
272 |
+
|
273 |
+
required_input = encoded_inputs[self.model_input_names[0]]
|
274 |
+
seq_length = len(required_input)
|
275 |
+
|
276 |
+
if padding_strategy == PaddingStrategy.LONGEST:
|
277 |
+
max_length = len(required_input)
|
278 |
+
|
279 |
+
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
|
280 |
+
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
281 |
+
|
282 |
+
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
|
283 |
+
|
284 |
+
# Initialize attention mask if not present.
|
285 |
+
if "attention_mask" not in encoded_inputs:
|
286 |
+
encoded_inputs["attention_mask"] = [1] * seq_length
|
287 |
+
|
288 |
+
if "position_ids" not in encoded_inputs:
|
289 |
+
encoded_inputs["position_ids"] = list(range(seq_length))
|
290 |
+
|
291 |
+
if needs_to_be_padded:
|
292 |
+
difference = max_length - len(required_input)
|
293 |
+
|
294 |
+
if "attention_mask" in encoded_inputs:
|
295 |
+
encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
|
296 |
+
if "position_ids" in encoded_inputs:
|
297 |
+
encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
|
298 |
+
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
|
299 |
+
|
300 |
+
return encoded_inputs
|
chatglm3-6b/tokenizer.model
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e7dc4c393423b76e4373e5157ddc34803a0189ba96b21ddbb40269d31468a6f2
|
3 |
+
size 1018370
|
chatglm3-6b/tokenizer_config.json
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"name_or_path": "THUDM/chatglm3-6b",
|
3 |
+
"remove_space": false,
|
4 |
+
"do_lower_case": false,
|
5 |
+
"tokenizer_class": "ChatGLMTokenizer",
|
6 |
+
"auto_map": {
|
7 |
+
"AutoTokenizer": [
|
8 |
+
"tokenization_chatglm.ChatGLMTokenizer",
|
9 |
+
null
|
10 |
+
]
|
11 |
+
}
|
12 |
+
}
|
composite_demo/.streamlit/config.toml
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
[theme]
|
2 |
+
font = "monospace"
|
composite_demo/README.md
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ChatGLM3 Web Demo
|
2 |
+
|
3 |
+

|
4 |
+
|
5 |
+
## 安装
|
6 |
+
|
7 |
+
我们建议通过 [Conda](https://docs.conda.io/en/latest/) 进行环境管理。
|
8 |
+
|
9 |
+
执行以下命令新建一个 conda 环境并安装所需依赖:
|
10 |
+
|
11 |
+
```bash
|
12 |
+
conda create -n chatglm3-demo python=3.10
|
13 |
+
conda activate chatglm3-demo
|
14 |
+
pip install -r requirements.txt
|
15 |
+
```
|
16 |
+
|
17 |
+
请注意,本项目需要 Python 3.10 或更高版本。
|
18 |
+
|
19 |
+
此外,使用 Code Interpreter 还需要安装 Jupyter 内核:
|
20 |
+
|
21 |
+
```bash
|
22 |
+
ipython kernel install --name chatglm3-demo --user
|
23 |
+
```
|
24 |
+
|
25 |
+
## 运行
|
26 |
+
|
27 |
+
运行以下命令在本地加载模型并启动 demo:
|
28 |
+
|
29 |
+
```bash
|
30 |
+
streamlit run main.py
|
31 |
+
```
|
32 |
+
|
33 |
+
之后即可从命令行中看到 demo 的地址,点击即可访问。初次访问需要下载并加载模型,可能需要花费一定时间。
|
34 |
+
|
35 |
+
如果已经在本地下载了模型,可以通过 `export MODEL_PATH=/path/to/model` 来指定从本地加载模型。如果需要自定义 Jupyter 内核,可以通过 `export IPYKERNEL=<kernel_name>` 来指定。
|
36 |
+
|
37 |
+
## 使用
|
38 |
+
|
39 |
+
ChatGLM3 Demo 拥有三种模式:
|
40 |
+
|
41 |
+
- Chat: 对话模式,在此模式下可以与模型进行对话。
|
42 |
+
- Tool: 工具模式,模型除了对话外,还可以通过工具进行其他操作。
|
43 |
+
- Code Interpreter: 代码解释器模式,模型可以在一个 Jupyter 环境中执行代码并获取结果,以完成复杂任务。
|
44 |
+
|
45 |
+
### 对话模式
|
46 |
+
|
47 |
+
对话模式下,用户可以直接在侧边栏修改 top_p, temperature, System Prompt 等参数来调整模型的行为。例如
|
48 |
+
|
49 |
+

|
50 |
+
|
51 |
+
### 工具模式
|
52 |
+
|
53 |
+
可以通过在 `tool_registry.py` 中注册新的工具来增强模型的能力。只需要使用 `@register_tool` 装饰函数即可完成注册。对于工具声明,函数名称即为工具的名称,函数 docstring 即为工具的说明;对于工具的参数,使用 `Annotated[typ: type, description: str, required: bool]` 标注参数的类型、描述和是否必须。
|
54 |
+
|
55 |
+
例如,`get_weather` 工具的注册如下:
|
56 |
+
|
57 |
+
```python
|
58 |
+
@register_tool
|
59 |
+
def get_weather(
|
60 |
+
city_name: Annotated[str, 'The name of the city to be queried', True],
|
61 |
+
) -> str:
|
62 |
+
"""
|
63 |
+
Get the weather for `city_name` in the following week
|
64 |
+
"""
|
65 |
+
...
|
66 |
+
```
|
67 |
+
|
68 |
+

|
69 |
+
|
70 |
+
此外,你也可以在页面中通过 `Manual mode` 进入手动模式,在这一模式下你可以通过 YAML 来直接指定工具列表,但你需要手动将工具的输出反馈给模型。
|
71 |
+
|
72 |
+
### 代码解释器模式
|
73 |
+
|
74 |
+
由于拥有代码执行环境,此模式下的模型能够执行更为复杂的任务,例如绘制图表、执行符号运算等等。模型会根据对任务完成情况的理解自动地连续执行多个代码块,直到任务完成。因此,在这一模式下,你只需要指明希望模型执行的任务即可。
|
75 |
+
|
76 |
+
例如,我们可以让 ChatGLM3 画一个爱心:
|
77 |
+
|
78 |
+

|
79 |
+
|
80 |
+
### 额外技巧
|
81 |
+
|
82 |
+
- 在模型生成文本时,可以通过页面右上角的 `Stop` 按钮进行打断。
|
83 |
+
- 刷新页面即可清空对话记录。
|
84 |
+
|
85 |
+
# Enjoy!
|
composite_demo/README_en.md
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ChatGLM3 Web Demo
|
2 |
+
|
3 |
+

|
4 |
+
|
5 |
+
## Installation
|
6 |
+
|
7 |
+
We recommend managing environments through [Conda](https://docs.conda.io/en/latest/).
|
8 |
+
|
9 |
+
Execute the following commands to create a new conda environment and install the necessary dependencies:
|
10 |
+
|
11 |
+
```bash
|
12 |
+
conda create -n chatglm3-demo python=3.10
|
13 |
+
conda activate chatglm3-demo
|
14 |
+
pip install -r requirements.txt
|
15 |
+
```
|
16 |
+
|
17 |
+
Please note that this project requires Python 3.10 or higher.
|
18 |
+
|
19 |
+
Additionally, installing the Jupyter kernel is required for using the Code Interpreter:
|
20 |
+
|
21 |
+
```bash
|
22 |
+
ipython kernel install --name chatglm3-demo --user
|
23 |
+
```
|
24 |
+
|
25 |
+
## Execution
|
26 |
+
|
27 |
+
Run the following command to load the model locally and start the demo:
|
28 |
+
|
29 |
+
```bash
|
30 |
+
streamlit run main.py
|
31 |
+
```
|
32 |
+
|
33 |
+
Afterward, the address of the demo can be seen from the command line; click to access. The first visit requires the download and loading of the model, which may take some time.
|
34 |
+
|
35 |
+
If the model has already been downloaded locally, you can specify to load the model locally through `export MODEL_PATH=/path/to/model`. If you need to customize the Jupyter kernel, you can specify it through `export IPYKERNEL=<kernel_name>`.
|
36 |
+
|
37 |
+
## Usage
|
38 |
+
|
39 |
+
ChatGLM3 Demo has three modes:
|
40 |
+
|
41 |
+
- Chat: Dialogue mode, where you can interact with the model.
|
42 |
+
- Tool: Tool mode, where the model, in addition to dialogue, can perform other operations through tools.
|
43 |
+
- Code Interpreter: Code interpreter mode, where the model can execute code in a Jupyter environment and obtain results to complete complex tasks.
|
44 |
+
|
45 |
+
### Dialogue Mode
|
46 |
+
|
47 |
+
In dialogue mode, users can directly modify parameters such as top_p, temperature, System Prompt in the sidebar to adjust the behavior of the model. For example,
|
48 |
+
|
49 |
+

|
50 |
+
|
51 |
+
### Tool Mode
|
52 |
+
|
53 |
+
You can enhance the model's capabilities by registering new tools in `tool_registry.py`. Just use the `@register_tool` decorator to complete the registration. For tool declarations, the function name is the name of the tool, and the function docstring is the description of the tool; for tool parameters, use `Annotated[typ: type, description: str, required: bool]` to annotate the type, description, and whether it is necessary of the parameters.
|
54 |
+
|
55 |
+
For example, the registration of the `get_weather` tool is as follows:
|
56 |
+
|
57 |
+
```python
|
58 |
+
@register_tool
|
59 |
+
def get_weather(
|
60 |
+
city_name: Annotated[str, 'The name of the city to be queried', True],
|
61 |
+
) -> str:
|
62 |
+
"""
|
63 |
+
Get the weather for `city_name` in the following week
|
64 |
+
"""
|
65 |
+
...
|
66 |
+
```
|
67 |
+
|
68 |
+

|
69 |
+
|
70 |
+
Additionally, you can enter the manual mode through `Manual mode` on the page. In this mode, you can directly specify the tool list through YAML, but you need to manually feed back the tool's output to the model.
|
71 |
+
|
72 |
+
### Code Interpreter Mode
|
73 |
+
|
74 |
+
Due to having a code execution environment, the model in this mode can perform more complex tasks, such as drawing charts, performing symbolic operations, etc. The model will automatically execute multiple code blocks in succession based on its understanding of the task completion status until the task is completed. Therefore, in this mode, you only need to specify the task you want the model to perform.
|
75 |
+
|
76 |
+
For example, we can ask ChatGLM3 to draw a heart:
|
77 |
+
|
78 |
+

|
79 |
+
|
80 |
+
### Additional Tips
|
81 |
+
|
82 |
+
- While the model is generating text, it can be interrupted by the `Stop` button at the top right corner of the page.
|
83 |
+
- Refreshing the page will clear the dialogue history.
|
84 |
+
|
85 |
+
# Enjoy!
|
composite_demo/assets/demo.png
ADDED
![]() |
composite_demo/assets/emojis.png
ADDED
![]() |
composite_demo/assets/heart.png
ADDED
![]() |
composite_demo/assets/tool.png
ADDED
![]() |
composite_demo/client.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from collections.abc import Iterable
|
4 |
+
import os
|
5 |
+
from typing import Any, Protocol
|
6 |
+
|
7 |
+
from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token
|
8 |
+
import streamlit as st
|
9 |
+
import torch
|
10 |
+
from transformers import AutoModel, AutoTokenizer, AutoConfig
|
11 |
+
|
12 |
+
from conversation import Conversation
|
13 |
+
|
14 |
+
TOOL_PROMPT = 'Answer the following questions as best as you can. You have access to the following tools:'
|
15 |
+
|
16 |
+
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
|
17 |
+
PT_PATH = os.environ.get('PT_PATH', None)
|
18 |
+
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
|
19 |
+
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
20 |
+
|
21 |
+
# for Mac Computer like M1
|
22 |
+
# You Need Use Pytorch compiled with Metal
|
23 |
+
# DEVICE = 'mps'
|
24 |
+
|
25 |
+
# for AMD gpu likes MI100 (Not Official Steady Support yet)
|
26 |
+
# You Need Use Pytorch compiled with ROCm
|
27 |
+
# DEVICE = 'cuda'
|
28 |
+
|
29 |
+
# for Intel gpu likes A770 (Not Official Steady Support yet)
|
30 |
+
# You Need Use Pytorch compiled with oneDNN and install intel-extension-for-pytorch
|
31 |
+
# import intel_extension_for_pytorch as ipex
|
32 |
+
# DEVICE = 'xpu'
|
33 |
+
|
34 |
+
# for Moore Threads gpu like MTT S80 (Not Official Steady Support yet)
|
35 |
+
# You Need Use Pytorch compiled with Musa
|
36 |
+
# DEVICE = 'musa'
|
37 |
+
|
38 |
+
|
39 |
+
@st.cache_resource
|
40 |
+
def get_client() -> Client:
|
41 |
+
client = HFClient(MODEL_PATH, TOKENIZER_PATH, PT_PATH, DEVICE)
|
42 |
+
return client
|
43 |
+
|
44 |
+
|
45 |
+
class Client(Protocol):
|
46 |
+
def generate_stream(self,
|
47 |
+
system: str | None,
|
48 |
+
tools: list[dict] | None,
|
49 |
+
history: list[Conversation],
|
50 |
+
**parameters: Any
|
51 |
+
) -> Iterable[TextGenerationStreamResponse]:
|
52 |
+
...
|
53 |
+
|
54 |
+
|
55 |
+
def stream_chat(self, tokenizer, query: str, history: list[tuple[str, str]] = None, role: str = "user",
|
56 |
+
past_key_values=None, max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
|
57 |
+
repetition_penalty=1.0, length_penalty=1.0, num_beams=1,
|
58 |
+
logits_processor=None, return_past_key_values=False, **kwargs):
|
59 |
+
from transformers.generation.logits_process import LogitsProcessor
|
60 |
+
from transformers.generation.utils import LogitsProcessorList
|
61 |
+
|
62 |
+
class InvalidScoreLogitsProcessor(LogitsProcessor):
|
63 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
64 |
+
if torch.isnan(scores).any() or torch.isinf(scores).any():
|
65 |
+
scores.zero_()
|
66 |
+
scores[..., 5] = 5e4
|
67 |
+
return scores
|
68 |
+
|
69 |
+
if history is None:
|
70 |
+
history = []
|
71 |
+
if logits_processor is None:
|
72 |
+
logits_processor = LogitsProcessorList()
|
73 |
+
logits_processor.append(InvalidScoreLogitsProcessor())
|
74 |
+
eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
|
75 |
+
tokenizer.get_command("<|observation|>")]
|
76 |
+
gen_kwargs = {"max_length": max_length,
|
77 |
+
"do_sample": do_sample,
|
78 |
+
"top_p": top_p,
|
79 |
+
"temperature": temperature,
|
80 |
+
"logits_processor": logits_processor,
|
81 |
+
"repetition_penalty": repetition_penalty,
|
82 |
+
"length_penalty": length_penalty,
|
83 |
+
"num_beams": num_beams,
|
84 |
+
**kwargs
|
85 |
+
}
|
86 |
+
|
87 |
+
print(gen_kwargs)
|
88 |
+
if past_key_values is None:
|
89 |
+
inputs = tokenizer.build_chat_input(query, history=history, role=role)
|
90 |
+
else:
|
91 |
+
inputs = tokenizer.build_chat_input(query, role=role)
|
92 |
+
inputs = inputs.to(self.device)
|
93 |
+
if past_key_values is not None:
|
94 |
+
past_length = past_key_values[0][0].shape[0]
|
95 |
+
if self.transformer.pre_seq_len is not None:
|
96 |
+
past_length -= self.transformer.pre_seq_len
|
97 |
+
inputs.position_ids += past_length
|
98 |
+
attention_mask = inputs.attention_mask
|
99 |
+
attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
|
100 |
+
inputs['attention_mask'] = attention_mask
|
101 |
+
history.append({"role": role, "content": query})
|
102 |
+
print("input_shape>", inputs['input_ids'].shape)
|
103 |
+
|
104 |
+
input_sequence_length = inputs['input_ids'].shape[1]
|
105 |
+
|
106 |
+
if max_length < input_sequence_length <= self.config.seq_length:
|
107 |
+
yield "Current input sequence length {} exceeds sequence length set in generation parameters {}. The maximum model sequence length is {}. You may adjust the generation parameter to enable longer chat history.".format(
|
108 |
+
input_sequence_length, max_length, self.config.seq_length
|
109 |
+
), history
|
110 |
+
return
|
111 |
+
|
112 |
+
if input_sequence_length > self.config.seq_length:
|
113 |
+
yield "Current input sequence length {} exceeds maximum model sequence length {}. Unable to generate tokens.".format(
|
114 |
+
input_sequence_length, self.config.seq_length
|
115 |
+
), history
|
116 |
+
return
|
117 |
+
|
118 |
+
for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
|
119 |
+
eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
|
120 |
+
**gen_kwargs):
|
121 |
+
if return_past_key_values:
|
122 |
+
outputs, past_key_values = outputs
|
123 |
+
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
|
124 |
+
response = tokenizer.decode(outputs)
|
125 |
+
if response and response[-1] != "�":
|
126 |
+
new_history = history
|
127 |
+
if return_past_key_values:
|
128 |
+
yield response, new_history, past_key_values
|
129 |
+
else:
|
130 |
+
yield response, new_history
|
131 |
+
|
132 |
+
|
133 |
+
class HFClient(Client):
|
134 |
+
def __init__(self, model_path: str, tokenizer_path: str, pt_checkpoint: str | None = None, DEVICE = 'cpu'):
|
135 |
+
self.model_path = model_path
|
136 |
+
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
|
137 |
+
|
138 |
+
if pt_checkpoint is not None:
|
139 |
+
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True, pre_seq_len=128)
|
140 |
+
self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True, config=config)
|
141 |
+
prefix_state_dict = torch.load(os.path.join(pt_checkpoint, "pytorch_model.bin"))
|
142 |
+
new_prefix_state_dict = {}
|
143 |
+
for k, v in prefix_state_dict.items():
|
144 |
+
if k.startswith("transformer.prefix_encoder."):
|
145 |
+
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
|
146 |
+
print("Loaded from pt checkpoints", new_prefix_state_dict.keys())
|
147 |
+
self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
|
148 |
+
else:
|
149 |
+
self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
|
150 |
+
|
151 |
+
self.model = self.model.to(DEVICE).eval() if 'cuda' in DEVICE else self.model.float().to(DEVICE).eval()
|
152 |
+
|
153 |
+
|
154 |
+
def generate_stream(self,
|
155 |
+
system: str | None,
|
156 |
+
tools: list[dict] | None,
|
157 |
+
history: list[Conversation],
|
158 |
+
**parameters: Any
|
159 |
+
) -> Iterable[TextGenerationStreamResponse]:
|
160 |
+
chat_history = [{
|
161 |
+
'role': 'system',
|
162 |
+
'content': system if not tools else TOOL_PROMPT,
|
163 |
+
}]
|
164 |
+
|
165 |
+
if tools:
|
166 |
+
chat_history[0]['tools'] = tools
|
167 |
+
|
168 |
+
for conversation in history[:-1]:
|
169 |
+
chat_history.append({
|
170 |
+
'role': str(conversation.role).removeprefix('<|').removesuffix('|>'),
|
171 |
+
'content': conversation.content,
|
172 |
+
})
|
173 |
+
|
174 |
+
query = history[-1].content
|
175 |
+
role = str(history[-1].role).removeprefix('<|').removesuffix('|>')
|
176 |
+
|
177 |
+
text = ''
|
178 |
+
|
179 |
+
for new_text, _ in stream_chat(self.model,
|
180 |
+
self.tokenizer,
|
181 |
+
query,
|
182 |
+
chat_history,
|
183 |
+
role,
|
184 |
+
**parameters,
|
185 |
+
):
|
186 |
+
word = new_text.removeprefix(text)
|
187 |
+
word_stripped = word.strip()
|
188 |
+
text = new_text
|
189 |
+
yield TextGenerationStreamResponse(
|
190 |
+
generated_text=text,
|
191 |
+
token=Token(
|
192 |
+
id=0,
|
193 |
+
logprob=0,
|
194 |
+
text=word,
|
195 |
+
special=word_stripped.startswith('<|') and word_stripped.endswith('|>'),
|
196 |
+
)
|
197 |
+
)
|
composite_demo/conversation.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from enum import auto, Enum
|
3 |
+
import json
|
4 |
+
|
5 |
+
from PIL.Image import Image
|
6 |
+
import streamlit as st
|
7 |
+
from streamlit.delta_generator import DeltaGenerator
|
8 |
+
|
9 |
+
TOOL_PROMPT = 'Answer the following questions as best as you can. You have access to the following tools:\n'
|
10 |
+
|
11 |
+
class Role(Enum):
|
12 |
+
SYSTEM = auto()
|
13 |
+
USER = auto()
|
14 |
+
ASSISTANT = auto()
|
15 |
+
TOOL = auto()
|
16 |
+
INTERPRETER = auto()
|
17 |
+
OBSERVATION = auto()
|
18 |
+
|
19 |
+
def __str__(self):
|
20 |
+
match self:
|
21 |
+
case Role.SYSTEM:
|
22 |
+
return "<|system|>"
|
23 |
+
case Role.USER:
|
24 |
+
return "<|user|>"
|
25 |
+
case Role.ASSISTANT | Role.TOOL | Role.INTERPRETER:
|
26 |
+
return "<|assistant|>"
|
27 |
+
case Role.OBSERVATION:
|
28 |
+
return "<|observation|>"
|
29 |
+
|
30 |
+
# Get the message block for the given role
|
31 |
+
def get_message(self):
|
32 |
+
# Compare by value here, because the enum object in the session state
|
33 |
+
# is not the same as the enum cases here, due to streamlit's rerunning
|
34 |
+
# behavior.
|
35 |
+
match self.value:
|
36 |
+
case Role.SYSTEM.value:
|
37 |
+
return
|
38 |
+
case Role.USER.value:
|
39 |
+
return st.chat_message(name="user", avatar="user")
|
40 |
+
case Role.ASSISTANT.value:
|
41 |
+
return st.chat_message(name="assistant", avatar="assistant")
|
42 |
+
case Role.TOOL.value:
|
43 |
+
return st.chat_message(name="tool", avatar="assistant")
|
44 |
+
case Role.INTERPRETER.value:
|
45 |
+
return st.chat_message(name="interpreter", avatar="assistant")
|
46 |
+
case Role.OBSERVATION.value:
|
47 |
+
return st.chat_message(name="observation", avatar="user")
|
48 |
+
case _:
|
49 |
+
st.error(f'Unexpected role: {self}')
|
50 |
+
|
51 |
+
@dataclass
|
52 |
+
class Conversation:
|
53 |
+
role: Role
|
54 |
+
content: str
|
55 |
+
tool: str | None = None
|
56 |
+
image: Image | None = None
|
57 |
+
|
58 |
+
def __str__(self) -> str:
|
59 |
+
print(self.role, self.content, self.tool)
|
60 |
+
match self.role:
|
61 |
+
case Role.SYSTEM | Role.USER | Role.ASSISTANT | Role.OBSERVATION:
|
62 |
+
return f'{self.role}\n{self.content}'
|
63 |
+
case Role.TOOL:
|
64 |
+
return f'{self.role}{self.tool}\n{self.content}'
|
65 |
+
case Role.INTERPRETER:
|
66 |
+
return f'{self.role}interpreter\n{self.content}'
|
67 |
+
|
68 |
+
# Human readable format
|
69 |
+
def get_text(self) -> str:
|
70 |
+
text = postprocess_text(self.content)
|
71 |
+
match self.role.value:
|
72 |
+
case Role.TOOL.value:
|
73 |
+
text = f'Calling tool `{self.tool}`:\n{text}'
|
74 |
+
case Role.INTERPRETER.value:
|
75 |
+
text = f'{text}'
|
76 |
+
case Role.OBSERVATION.value:
|
77 |
+
text = f'Observation:\n```\n{text}\n```'
|
78 |
+
return text
|
79 |
+
|
80 |
+
# Display as a markdown block
|
81 |
+
def show(self, placeholder: DeltaGenerator | None=None) -> str:
|
82 |
+
if placeholder:
|
83 |
+
message = placeholder
|
84 |
+
else:
|
85 |
+
message = self.role.get_message()
|
86 |
+
if self.image:
|
87 |
+
message.image(self.image)
|
88 |
+
else:
|
89 |
+
text = self.get_text()
|
90 |
+
message.markdown(text)
|
91 |
+
|
92 |
+
def preprocess_text(
|
93 |
+
system: str | None,
|
94 |
+
tools: list[dict] | None,
|
95 |
+
history: list[Conversation],
|
96 |
+
) -> str:
|
97 |
+
if tools:
|
98 |
+
tools = json.dumps(tools, indent=4, ensure_ascii=False)
|
99 |
+
|
100 |
+
prompt = f"{Role.SYSTEM}\n"
|
101 |
+
prompt += system if not tools else TOOL_PROMPT
|
102 |
+
if tools:
|
103 |
+
tools = json.loads(tools)
|
104 |
+
prompt += json.dumps(tools, ensure_ascii=False)
|
105 |
+
for conversation in history:
|
106 |
+
prompt += f'{conversation}'
|
107 |
+
prompt += f'{Role.ASSISTANT}\n'
|
108 |
+
return prompt
|
109 |
+
|
110 |
+
def postprocess_text(text: str) -> str:
|
111 |
+
text = text.replace("\(", "$")
|
112 |
+
text = text.replace("\)", "$")
|
113 |
+
text = text.replace("\[", "$$")
|
114 |
+
text = text.replace("\]", "$$")
|
115 |
+
text = text.replace("<|assistant|>", "")
|
116 |
+
text = text.replace("<|observation|>", "")
|
117 |
+
text = text.replace("<|system|>", "")
|
118 |
+
text = text.replace("<|user|>", "")
|
119 |
+
return text.strip()
|
composite_demo/demo_chat.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from streamlit.delta_generator import DeltaGenerator
|
3 |
+
|
4 |
+
from client import get_client
|
5 |
+
from conversation import postprocess_text, preprocess_text, Conversation, Role
|
6 |
+
|
7 |
+
MAX_LENGTH = 8192
|
8 |
+
|
9 |
+
client = get_client()
|
10 |
+
|
11 |
+
|
12 |
+
# Append a conversation into history, while show it in a new markdown block
|
13 |
+
def append_conversation(
|
14 |
+
conversation: Conversation,
|
15 |
+
history: list[Conversation],
|
16 |
+
placeholder: DeltaGenerator | None = None,
|
17 |
+
) -> None:
|
18 |
+
history.append(conversation)
|
19 |
+
conversation.show(placeholder)
|
20 |
+
|
21 |
+
|
22 |
+
def main(top_p: float, temperature: float, system_prompt: str, prompt_text: str, repetition_penalty: float):
|
23 |
+
placeholder = st.empty()
|
24 |
+
with placeholder.container():
|
25 |
+
if 'chat_history' not in st.session_state:
|
26 |
+
st.session_state.chat_history = []
|
27 |
+
|
28 |
+
history: list[Conversation] = st.session_state.chat_history
|
29 |
+
|
30 |
+
for conversation in history:
|
31 |
+
conversation.show()
|
32 |
+
|
33 |
+
if prompt_text:
|
34 |
+
prompt_text = prompt_text.strip()
|
35 |
+
append_conversation(Conversation(Role.USER, prompt_text), history)
|
36 |
+
|
37 |
+
input_text = preprocess_text(
|
38 |
+
system_prompt,
|
39 |
+
tools=None,
|
40 |
+
history=history,
|
41 |
+
)
|
42 |
+
print("=== Input:")
|
43 |
+
print(input_text)
|
44 |
+
print("=== History:")
|
45 |
+
print(history)
|
46 |
+
|
47 |
+
placeholder = st.empty()
|
48 |
+
message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
|
49 |
+
markdown_placeholder = message_placeholder.empty()
|
50 |
+
|
51 |
+
output_text = ''
|
52 |
+
for response in client.generate_stream(
|
53 |
+
system_prompt,
|
54 |
+
tools=None,
|
55 |
+
history=history,
|
56 |
+
do_sample=True,
|
57 |
+
max_length=MAX_LENGTH,
|
58 |
+
temperature=temperature,
|
59 |
+
top_p=top_p,
|
60 |
+
stop_sequences=[str(Role.USER)],
|
61 |
+
repetition_penalty=repetition_penalty,
|
62 |
+
):
|
63 |
+
token = response.token
|
64 |
+
if response.token.special:
|
65 |
+
print("=== Output:")
|
66 |
+
print(output_text)
|
67 |
+
|
68 |
+
match token.text.strip():
|
69 |
+
case '<|user|>':
|
70 |
+
break
|
71 |
+
case _:
|
72 |
+
st.error(f'Unexpected special token: {token.text.strip()}')
|
73 |
+
break
|
74 |
+
output_text += response.token.text
|
75 |
+
markdown_placeholder.markdown(postprocess_text(output_text + '▌'))
|
76 |
+
|
77 |
+
append_conversation(Conversation(
|
78 |
+
Role.ASSISTANT,
|
79 |
+
postprocess_text(output_text),
|
80 |
+
), history, markdown_placeholder)
|
composite_demo/demo_ci.py
ADDED
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
from io import BytesIO
|
3 |
+
import os
|
4 |
+
from pprint import pprint
|
5 |
+
import queue
|
6 |
+
import re
|
7 |
+
from subprocess import PIPE
|
8 |
+
|
9 |
+
import jupyter_client
|
10 |
+
from PIL import Image
|
11 |
+
import streamlit as st
|
12 |
+
from streamlit.delta_generator import DeltaGenerator
|
13 |
+
|
14 |
+
from client import get_client
|
15 |
+
from conversation import postprocess_text, preprocess_text, Conversation, Role
|
16 |
+
|
17 |
+
IPYKERNEL = os.environ.get('IPYKERNEL', 'chatglm3-demo')
|
18 |
+
|
19 |
+
SYSTEM_PROMPT = '你是一位智能AI助手,你叫ChatGLM,你连接着一台电脑,但请注意不能联网。在使用Python解决任务时,你可以运行代码并得到结果,如果运行结果有错误,你需要尽可能对代码进行改进。你可以处理用户上传到电脑上的文件,文件默认存储路径是/mnt/data/。'
|
20 |
+
|
21 |
+
MAX_LENGTH = 8192
|
22 |
+
TRUNCATE_LENGTH = 1024
|
23 |
+
|
24 |
+
client = get_client()
|
25 |
+
|
26 |
+
class CodeKernel(object):
|
27 |
+
def __init__(self,
|
28 |
+
kernel_name='kernel',
|
29 |
+
kernel_id=None,
|
30 |
+
kernel_config_path="",
|
31 |
+
python_path=None,
|
32 |
+
ipython_path=None,
|
33 |
+
init_file_path="./startup.py",
|
34 |
+
verbose=1):
|
35 |
+
|
36 |
+
self.kernel_name = kernel_name
|
37 |
+
self.kernel_id = kernel_id
|
38 |
+
self.kernel_config_path = kernel_config_path
|
39 |
+
self.python_path = python_path
|
40 |
+
self.ipython_path = ipython_path
|
41 |
+
self.init_file_path = init_file_path
|
42 |
+
self.verbose = verbose
|
43 |
+
|
44 |
+
if python_path is None and ipython_path is None:
|
45 |
+
env = None
|
46 |
+
else:
|
47 |
+
env = {"PATH": self.python_path + ":$PATH", "PYTHONPATH": self.python_path}
|
48 |
+
|
49 |
+
# Initialize the backend kernel
|
50 |
+
self.kernel_manager = jupyter_client.KernelManager(kernel_name=IPYKERNEL,
|
51 |
+
connection_file=self.kernel_config_path,
|
52 |
+
exec_files=[self.init_file_path],
|
53 |
+
env=env)
|
54 |
+
if self.kernel_config_path:
|
55 |
+
self.kernel_manager.load_connection_file()
|
56 |
+
self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE)
|
57 |
+
print("Backend kernel started with the configuration: {}".format(
|
58 |
+
self.kernel_config_path))
|
59 |
+
else:
|
60 |
+
self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE)
|
61 |
+
print("Backend kernel started with the configuration: {}".format(
|
62 |
+
self.kernel_manager.connection_file))
|
63 |
+
|
64 |
+
if verbose:
|
65 |
+
pprint(self.kernel_manager.get_connection_info())
|
66 |
+
|
67 |
+
# Initialize the code kernel
|
68 |
+
self.kernel = self.kernel_manager.blocking_client()
|
69 |
+
# self.kernel.load_connection_file()
|
70 |
+
self.kernel.start_channels()
|
71 |
+
print("Code kernel started.")
|
72 |
+
|
73 |
+
def execute(self, code):
|
74 |
+
self.kernel.execute(code)
|
75 |
+
try:
|
76 |
+
shell_msg = self.kernel.get_shell_msg(timeout=30)
|
77 |
+
io_msg_content = self.kernel.get_iopub_msg(timeout=30)['content']
|
78 |
+
while True:
|
79 |
+
msg_out = io_msg_content
|
80 |
+
### Poll the message
|
81 |
+
try:
|
82 |
+
io_msg_content = self.kernel.get_iopub_msg(timeout=30)['content']
|
83 |
+
if 'execution_state' in io_msg_content and io_msg_content['execution_state'] == 'idle':
|
84 |
+
break
|
85 |
+
except queue.Empty:
|
86 |
+
break
|
87 |
+
|
88 |
+
return shell_msg, msg_out
|
89 |
+
except Exception as e:
|
90 |
+
print(e)
|
91 |
+
return None
|
92 |
+
|
93 |
+
def execute_interactive(self, code, verbose=False):
|
94 |
+
shell_msg = self.kernel.execute_interactive(code)
|
95 |
+
if shell_msg is queue.Empty:
|
96 |
+
if verbose:
|
97 |
+
print("Timeout waiting for shell message.")
|
98 |
+
self.check_msg(shell_msg, verbose=verbose)
|
99 |
+
|
100 |
+
return shell_msg
|
101 |
+
|
102 |
+
def inspect(self, code, verbose=False):
|
103 |
+
msg_id = self.kernel.inspect(code)
|
104 |
+
shell_msg = self.kernel.get_shell_msg(timeout=30)
|
105 |
+
if shell_msg is queue.Empty:
|
106 |
+
if verbose:
|
107 |
+
print("Timeout waiting for shell message.")
|
108 |
+
self.check_msg(shell_msg, verbose=verbose)
|
109 |
+
|
110 |
+
return shell_msg
|
111 |
+
|
112 |
+
def get_error_msg(self, msg, verbose=False) -> str | None:
|
113 |
+
if msg['content']['status'] == 'error':
|
114 |
+
try:
|
115 |
+
error_msg = msg['content']['traceback']
|
116 |
+
except:
|
117 |
+
try:
|
118 |
+
error_msg = msg['content']['traceback'][-1].strip()
|
119 |
+
except:
|
120 |
+
error_msg = "Traceback Error"
|
121 |
+
if verbose:
|
122 |
+
print("Error: ", error_msg)
|
123 |
+
return error_msg
|
124 |
+
return None
|
125 |
+
|
126 |
+
def check_msg(self, msg, verbose=False):
|
127 |
+
status = msg['content']['status']
|
128 |
+
if status == 'ok':
|
129 |
+
if verbose:
|
130 |
+
print("Execution succeeded.")
|
131 |
+
elif status == 'error':
|
132 |
+
for line in msg['content']['traceback']:
|
133 |
+
if verbose:
|
134 |
+
print(line)
|
135 |
+
|
136 |
+
def shutdown(self):
|
137 |
+
# Shutdown the backend kernel
|
138 |
+
self.kernel_manager.shutdown_kernel()
|
139 |
+
print("Backend kernel shutdown.")
|
140 |
+
# Shutdown the code kernel
|
141 |
+
self.kernel.shutdown()
|
142 |
+
print("Code kernel shutdown.")
|
143 |
+
|
144 |
+
def restart(self):
|
145 |
+
# Restart the backend kernel
|
146 |
+
self.kernel_manager.restart_kernel()
|
147 |
+
# print("Backend kernel restarted.")
|
148 |
+
|
149 |
+
def interrupt(self):
|
150 |
+
# Interrupt the backend kernel
|
151 |
+
self.kernel_manager.interrupt_kernel()
|
152 |
+
# print("Backend kernel interrupted.")
|
153 |
+
|
154 |
+
def is_alive(self):
|
155 |
+
return self.kernel.is_alive()
|
156 |
+
|
157 |
+
def b64_2_img(data):
|
158 |
+
buff = BytesIO(base64.b64decode(data))
|
159 |
+
return Image.open(buff)
|
160 |
+
|
161 |
+
def clean_ansi_codes(input_string):
|
162 |
+
ansi_escape = re.compile(r'(\x9B|\x1B\[|\u001b\[)[0-?]*[ -/]*[@-~]')
|
163 |
+
return ansi_escape.sub('', input_string)
|
164 |
+
|
165 |
+
def execute(code, kernel: CodeKernel) -> tuple[str, str | Image.Image]:
|
166 |
+
res = ""
|
167 |
+
res_type = None
|
168 |
+
code = code.replace("<|observation|>", "")
|
169 |
+
code = code.replace("<|assistant|>interpreter", "")
|
170 |
+
code = code.replace("<|assistant|>", "")
|
171 |
+
code = code.replace("<|user|>", "")
|
172 |
+
code = code.replace("<|system|>", "")
|
173 |
+
msg, output = kernel.execute(code)
|
174 |
+
|
175 |
+
if msg['metadata']['status'] == "timeout":
|
176 |
+
return res_type, 'Timed out'
|
177 |
+
elif msg['metadata']['status'] == 'error':
|
178 |
+
return res_type, clean_ansi_codes('\n'.join(kernel.get_error_msg(msg, verbose=True)))
|
179 |
+
|
180 |
+
if 'text' in output:
|
181 |
+
res_type = "text"
|
182 |
+
res = output['text']
|
183 |
+
elif 'data' in output:
|
184 |
+
for key in output['data']:
|
185 |
+
if 'text/plain' in key:
|
186 |
+
res_type = "text"
|
187 |
+
res = output['data'][key]
|
188 |
+
elif 'image/png' in key:
|
189 |
+
res_type = "image"
|
190 |
+
res = output['data'][key]
|
191 |
+
break
|
192 |
+
|
193 |
+
if res_type == "image":
|
194 |
+
return res_type, b64_2_img(res)
|
195 |
+
elif res_type == "text" or res_type == "traceback":
|
196 |
+
res = res
|
197 |
+
|
198 |
+
return res_type, res
|
199 |
+
|
200 |
+
@st.cache_resource
|
201 |
+
def get_kernel():
|
202 |
+
kernel = CodeKernel()
|
203 |
+
return kernel
|
204 |
+
|
205 |
+
def extract_code(text: str) -> str:
|
206 |
+
pattern = r'```([^\n]*)\n(.*?)```'
|
207 |
+
matches = re.findall(pattern, text, re.DOTALL)
|
208 |
+
return matches[-1][1]
|
209 |
+
|
210 |
+
# Append a conversation into history, while show it in a new markdown block
|
211 |
+
def append_conversation(
|
212 |
+
conversation: Conversation,
|
213 |
+
history: list[Conversation],
|
214 |
+
placeholder: DeltaGenerator | None=None,
|
215 |
+
) -> None:
|
216 |
+
history.append(conversation)
|
217 |
+
conversation.show(placeholder)
|
218 |
+
|
219 |
+
def main(top_p: float, temperature: float, prompt_text: str, repetition_penalty: float):
|
220 |
+
if 'ci_history' not in st.session_state:
|
221 |
+
st.session_state.ci_history = []
|
222 |
+
|
223 |
+
history: list[Conversation] = st.session_state.ci_history
|
224 |
+
|
225 |
+
for conversation in history:
|
226 |
+
conversation.show()
|
227 |
+
|
228 |
+
if prompt_text:
|
229 |
+
prompt_text = prompt_text.strip()
|
230 |
+
role = Role.USER
|
231 |
+
append_conversation(Conversation(role, prompt_text), history)
|
232 |
+
|
233 |
+
input_text = preprocess_text(
|
234 |
+
SYSTEM_PROMPT,
|
235 |
+
None,
|
236 |
+
history,
|
237 |
+
)
|
238 |
+
print("=== Input:")
|
239 |
+
print(input_text)
|
240 |
+
print("=== History:")
|
241 |
+
print(history)
|
242 |
+
|
243 |
+
placeholder = st.container()
|
244 |
+
message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
|
245 |
+
markdown_placeholder = message_placeholder.empty()
|
246 |
+
|
247 |
+
for _ in range(5):
|
248 |
+
output_text = ''
|
249 |
+
for response in client.generate_stream(
|
250 |
+
system=SYSTEM_PROMPT,
|
251 |
+
tools=None,
|
252 |
+
history=history,
|
253 |
+
do_sample=True,
|
254 |
+
max_length=MAX_LENGTH,
|
255 |
+
temperature=temperature,
|
256 |
+
top_p=top_p,
|
257 |
+
stop_sequences=[str(r) for r in (Role.USER, Role.OBSERVATION)],
|
258 |
+
repetition_penalty=repetition_penalty,
|
259 |
+
):
|
260 |
+
token = response.token
|
261 |
+
if response.token.special:
|
262 |
+
print("=== Output:")
|
263 |
+
print(output_text)
|
264 |
+
|
265 |
+
match token.text.strip():
|
266 |
+
case '<|user|>':
|
267 |
+
append_conversation(Conversation(
|
268 |
+
Role.ASSISTANT,
|
269 |
+
postprocess_text(output_text),
|
270 |
+
), history, markdown_placeholder)
|
271 |
+
return
|
272 |
+
# Initiate tool call
|
273 |
+
case '<|assistant|>':
|
274 |
+
append_conversation(Conversation(
|
275 |
+
Role.ASSISTANT,
|
276 |
+
postprocess_text(output_text),
|
277 |
+
), history, markdown_placeholder)
|
278 |
+
message_placeholder = placeholder.chat_message(name="interpreter", avatar="assistant")
|
279 |
+
markdown_placeholder = message_placeholder.empty()
|
280 |
+
output_text = ''
|
281 |
+
continue
|
282 |
+
case '<|observation|>':
|
283 |
+
code = extract_code(output_text)
|
284 |
+
print("Code:", code)
|
285 |
+
|
286 |
+
display_text = output_text.split('interpreter')[-1].strip()
|
287 |
+
append_conversation(Conversation(
|
288 |
+
Role.INTERPRETER,
|
289 |
+
postprocess_text(display_text),
|
290 |
+
), history, markdown_placeholder)
|
291 |
+
message_placeholder = placeholder.chat_message(name="observation", avatar="user")
|
292 |
+
markdown_placeholder = message_placeholder.empty()
|
293 |
+
output_text = ''
|
294 |
+
|
295 |
+
with markdown_placeholder:
|
296 |
+
with st.spinner('Executing code...'):
|
297 |
+
try:
|
298 |
+
res_type, res = execute(code, get_kernel())
|
299 |
+
except Exception as e:
|
300 |
+
st.error(f'Error when executing code: {e}')
|
301 |
+
return
|
302 |
+
print("Received:", res_type, res)
|
303 |
+
|
304 |
+
if res_type == 'text' and len(res) > TRUNCATE_LENGTH:
|
305 |
+
res = res[:TRUNCATE_LENGTH] + ' [TRUNCATED]'
|
306 |
+
|
307 |
+
append_conversation(Conversation(
|
308 |
+
Role.OBSERVATION,
|
309 |
+
'[Image]' if res_type == 'image' else postprocess_text(res),
|
310 |
+
tool=None,
|
311 |
+
image=res if res_type == 'image' else None,
|
312 |
+
), history, markdown_placeholder)
|
313 |
+
message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
|
314 |
+
markdown_placeholder = message_placeholder.empty()
|
315 |
+
output_text = ''
|
316 |
+
break
|
317 |
+
case _:
|
318 |
+
st.error(f'Unexpected special token: {token.text.strip()}')
|
319 |
+
break
|
320 |
+
output_text += response.token.text
|
321 |
+
display_text = output_text.split('interpreter')[-1].strip()
|
322 |
+
markdown_placeholder.markdown(postprocess_text(display_text + '▌'))
|
323 |
+
else:
|
324 |
+
append_conversation(Conversation(
|
325 |
+
Role.ASSISTANT,
|
326 |
+
postprocess_text(output_text),
|
327 |
+
), history, markdown_placeholder)
|
328 |
+
return
|
329 |
+
|
composite_demo/demo_tool.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import yaml
|
3 |
+
from yaml import YAMLError
|
4 |
+
|
5 |
+
import streamlit as st
|
6 |
+
from streamlit.delta_generator import DeltaGenerator
|
7 |
+
|
8 |
+
from client import get_client
|
9 |
+
from conversation import postprocess_text, preprocess_text, Conversation, Role
|
10 |
+
from tool_registry import dispatch_tool, get_tools
|
11 |
+
|
12 |
+
MAX_LENGTH = 8192
|
13 |
+
TRUNCATE_LENGTH = 1024
|
14 |
+
|
15 |
+
EXAMPLE_TOOL = {
|
16 |
+
"name": "get_current_weather",
|
17 |
+
"description": "Get the current weather in a given location",
|
18 |
+
"parameters": {
|
19 |
+
"type": "object",
|
20 |
+
"properties": {
|
21 |
+
"location": {
|
22 |
+
"type": "string",
|
23 |
+
"description": "The city and state, e.g. San Francisco, CA",
|
24 |
+
},
|
25 |
+
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
26 |
+
},
|
27 |
+
"required": ["location"],
|
28 |
+
}
|
29 |
+
}
|
30 |
+
|
31 |
+
client = get_client()
|
32 |
+
|
33 |
+
def tool_call(*args, **kwargs) -> dict:
|
34 |
+
print("=== Tool call:")
|
35 |
+
print(args)
|
36 |
+
print(kwargs)
|
37 |
+
st.session_state.calling_tool = True
|
38 |
+
return kwargs
|
39 |
+
|
40 |
+
def yaml_to_dict(tools: str) -> list[dict] | None:
|
41 |
+
try:
|
42 |
+
return yaml.safe_load(tools)
|
43 |
+
except YAMLError:
|
44 |
+
return None
|
45 |
+
|
46 |
+
def extract_code(text: str) -> str:
|
47 |
+
pattern = r'```([^\n]*)\n(.*?)```'
|
48 |
+
matches = re.findall(pattern, text, re.DOTALL)
|
49 |
+
return matches[-1][1]
|
50 |
+
|
51 |
+
# Append a conversation into history, while show it in a new markdown block
|
52 |
+
def append_conversation(
|
53 |
+
conversation: Conversation,
|
54 |
+
history: list[Conversation],
|
55 |
+
placeholder: DeltaGenerator | None=None,
|
56 |
+
) -> None:
|
57 |
+
history.append(conversation)
|
58 |
+
conversation.show(placeholder)
|
59 |
+
|
60 |
+
def main(top_p: float, temperature: float, prompt_text: str, repetition_penalty: float):
|
61 |
+
manual_mode = st.toggle('Manual mode',
|
62 |
+
help='Define your tools in YAML format. You need to supply tool call results manually.'
|
63 |
+
)
|
64 |
+
|
65 |
+
if manual_mode:
|
66 |
+
with st.expander('Tools'):
|
67 |
+
tools = st.text_area(
|
68 |
+
'Define your tools in YAML format here:',
|
69 |
+
yaml.safe_dump([EXAMPLE_TOOL], sort_keys=False),
|
70 |
+
height=400,
|
71 |
+
)
|
72 |
+
tools = yaml_to_dict(tools)
|
73 |
+
|
74 |
+
if not tools:
|
75 |
+
st.error('YAML format error in tools definition')
|
76 |
+
else:
|
77 |
+
tools = get_tools()
|
78 |
+
|
79 |
+
if 'tool_history' not in st.session_state:
|
80 |
+
st.session_state.tool_history = []
|
81 |
+
if 'calling_tool' not in st.session_state:
|
82 |
+
st.session_state.calling_tool = False
|
83 |
+
|
84 |
+
history: list[Conversation] = st.session_state.tool_history
|
85 |
+
|
86 |
+
for conversation in history:
|
87 |
+
conversation.show()
|
88 |
+
|
89 |
+
if prompt_text:
|
90 |
+
prompt_text = prompt_text.strip()
|
91 |
+
role = st.session_state.calling_tool and Role.OBSERVATION or Role.USER
|
92 |
+
append_conversation(Conversation(role, prompt_text), history)
|
93 |
+
st.session_state.calling_tool = False
|
94 |
+
|
95 |
+
input_text = preprocess_text(
|
96 |
+
None,
|
97 |
+
tools,
|
98 |
+
history,
|
99 |
+
)
|
100 |
+
print("=== Input:")
|
101 |
+
print(input_text)
|
102 |
+
print("=== History:")
|
103 |
+
print(history)
|
104 |
+
|
105 |
+
placeholder = st.container()
|
106 |
+
message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
|
107 |
+
markdown_placeholder = message_placeholder.empty()
|
108 |
+
|
109 |
+
for _ in range(5):
|
110 |
+
output_text = ''
|
111 |
+
for response in client.generate_stream(
|
112 |
+
system=None,
|
113 |
+
tools=tools,
|
114 |
+
history=history,
|
115 |
+
do_sample=True,
|
116 |
+
max_length=MAX_LENGTH,
|
117 |
+
temperature=temperature,
|
118 |
+
top_p=top_p,
|
119 |
+
stop_sequences=[str(r) for r in (Role.USER, Role.OBSERVATION)],
|
120 |
+
repetition_penalty=repetition_penalty,
|
121 |
+
):
|
122 |
+
token = response.token
|
123 |
+
if response.token.special:
|
124 |
+
print("=== Output:")
|
125 |
+
print(output_text)
|
126 |
+
|
127 |
+
match token.text.strip():
|
128 |
+
case '<|user|>':
|
129 |
+
append_conversation(Conversation(
|
130 |
+
Role.ASSISTANT,
|
131 |
+
postprocess_text(output_text),
|
132 |
+
), history, markdown_placeholder)
|
133 |
+
return
|
134 |
+
# Initiate tool call
|
135 |
+
case '<|assistant|>':
|
136 |
+
append_conversation(Conversation(
|
137 |
+
Role.ASSISTANT,
|
138 |
+
postprocess_text(output_text),
|
139 |
+
), history, markdown_placeholder)
|
140 |
+
output_text = ''
|
141 |
+
message_placeholder = placeholder.chat_message(name="tool", avatar="assistant")
|
142 |
+
markdown_placeholder = message_placeholder.empty()
|
143 |
+
continue
|
144 |
+
case '<|observation|>':
|
145 |
+
tool, *call_args_text = output_text.strip().split('\n')
|
146 |
+
call_args_text = '\n'.join(call_args_text)
|
147 |
+
|
148 |
+
append_conversation(Conversation(
|
149 |
+
Role.TOOL,
|
150 |
+
postprocess_text(output_text),
|
151 |
+
tool,
|
152 |
+
), history, markdown_placeholder)
|
153 |
+
message_placeholder = placeholder.chat_message(name="observation", avatar="user")
|
154 |
+
markdown_placeholder = message_placeholder.empty()
|
155 |
+
|
156 |
+
try:
|
157 |
+
code = extract_code(call_args_text)
|
158 |
+
args = eval(code, {'tool_call': tool_call}, {})
|
159 |
+
except:
|
160 |
+
st.error('Failed to parse tool call')
|
161 |
+
return
|
162 |
+
|
163 |
+
output_text = ''
|
164 |
+
|
165 |
+
if manual_mode:
|
166 |
+
st.info('Please provide tool call results below:')
|
167 |
+
return
|
168 |
+
else:
|
169 |
+
with markdown_placeholder:
|
170 |
+
with st.spinner(f'Calling tool {tool}...'):
|
171 |
+
observation = dispatch_tool(tool, args)
|
172 |
+
|
173 |
+
if len(observation) > TRUNCATE_LENGTH:
|
174 |
+
observation = observation[:TRUNCATE_LENGTH] + ' [TRUNCATED]'
|
175 |
+
append_conversation(Conversation(
|
176 |
+
Role.OBSERVATION, observation
|
177 |
+
), history, markdown_placeholder)
|
178 |
+
message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
|
179 |
+
markdown_placeholder = message_placeholder.empty()
|
180 |
+
st.session_state.calling_tool = False
|
181 |
+
break
|
182 |
+
case _:
|
183 |
+
st.error(f'Unexpected special token: {token.text.strip()}')
|
184 |
+
return
|
185 |
+
output_text += response.token.text
|
186 |
+
markdown_placeholder.markdown(postprocess_text(output_text + '▌'))
|
187 |
+
else:
|
188 |
+
append_conversation(Conversation(
|
189 |
+
Role.ASSISTANT,
|
190 |
+
postprocess_text(output_text),
|
191 |
+
), history, markdown_placeholder)
|
192 |
+
return
|
composite_demo/main.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import Enum
|
2 |
+
import streamlit as st
|
3 |
+
|
4 |
+
st.set_page_config(
|
5 |
+
page_title="ChatGLM3 Demo",
|
6 |
+
page_icon=":robot:",
|
7 |
+
layout='centered',
|
8 |
+
initial_sidebar_state='expanded',
|
9 |
+
)
|
10 |
+
|
11 |
+
import demo_chat, demo_ci, demo_tool
|
12 |
+
|
13 |
+
DEFAULT_SYSTEM_PROMPT = '''
|
14 |
+
You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.
|
15 |
+
'''.strip()
|
16 |
+
|
17 |
+
# Set the title of the demo
|
18 |
+
st.title("ChatGLM3 Demo")
|
19 |
+
|
20 |
+
# Add your custom text here, with smaller font size
|
21 |
+
st.markdown("<sub>智谱AI 公开在线技术文档: https://lslfd0slxc.feishu.cn/wiki/WvQbwIJ9tiPAxGk8ywDck6yfnof </sub> \n\n <sub> 更多 ChatGLM3-6B 的使用方法请参考文档。</sub>", unsafe_allow_html=True)
|
22 |
+
|
23 |
+
class Mode(str, Enum):
|
24 |
+
CHAT, TOOL, CI = '💬 Chat', '🛠️ Tool', '🧑💻 Code Interpreter'
|
25 |
+
|
26 |
+
|
27 |
+
with st.sidebar:
|
28 |
+
top_p = st.slider(
|
29 |
+
'top_p', 0.0, 1.0, 0.8, step=0.01
|
30 |
+
)
|
31 |
+
temperature = st.slider(
|
32 |
+
'temperature', 0.0, 1.5, 0.95, step=0.01
|
33 |
+
)
|
34 |
+
repetition_penalty = st.slider(
|
35 |
+
'repetition_penalty', 0.0, 2.0, 1.2, step=0.01
|
36 |
+
)
|
37 |
+
system_prompt = st.text_area(
|
38 |
+
label="System Prompt (Only for chat mode)",
|
39 |
+
height=300,
|
40 |
+
value=DEFAULT_SYSTEM_PROMPT,
|
41 |
+
)
|
42 |
+
|
43 |
+
prompt_text = st.chat_input(
|
44 |
+
'Chat with ChatGLM3!',
|
45 |
+
key='chat_input',
|
46 |
+
)
|
47 |
+
|
48 |
+
tab = st.radio(
|
49 |
+
'Mode',
|
50 |
+
[mode.value for mode in Mode],
|
51 |
+
horizontal=True,
|
52 |
+
label_visibility='hidden',
|
53 |
+
)
|
54 |
+
|
55 |
+
match tab:
|
56 |
+
case Mode.CHAT:
|
57 |
+
demo_chat.main(top_p, temperature, system_prompt, prompt_text, repetition_penalty)
|
58 |
+
case Mode.TOOL:
|
59 |
+
demo_tool.main(top_p, temperature, prompt_text, repetition_penalty)
|
60 |
+
case Mode.CI:
|
61 |
+
demo_ci.main(top_p, temperature, prompt_text, repetition_penalty)
|
62 |
+
case _:
|
63 |
+
st.error(f'Unexpected tab: {tab}')
|
composite_demo/requirements.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
huggingface_hub
|
2 |
+
ipykernel
|
3 |
+
ipython
|
4 |
+
jupyter_client
|
5 |
+
pillow
|
6 |
+
sentencepiece
|
7 |
+
streamlit
|
8 |
+
tokenizers
|
9 |
+
torch
|
10 |
+
transformers
|
11 |
+
pyyaml
|
12 |
+
requests
|