Duplicate from THUDM/cogvlm-chat-hf
Browse filesCo-authored-by: chenkq <[email protected]>
- .gitattributes +35 -0
- README.md +174 -0
- config.json +40 -0
- configuration_cogvlm.py +45 -0
- generation_config.json +7 -0
- model-00001-of-00008.safetensors +3 -0
- model-00002-of-00008.safetensors +3 -0
- model-00003-of-00008.safetensors +3 -0
- model-00004-of-00008.safetensors +3 -0
- model-00005-of-00008.safetensors +3 -0
- model-00006-of-00008.safetensors +3 -0
- model-00007-of-00008.safetensors +3 -0
- model-00008-of-00008.safetensors +3 -0
- model.safetensors.index.json +0 -0
- modeling_cogvlm.py +783 -0
- util.py +483 -0
- visual.py +135 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: apache-2.0
|
3 |
+
language:
|
4 |
+
- en
|
5 |
+
---
|
6 |
+
# CogVLM
|
7 |
+
|
8 |
+
**CogVLM** 是一个强大的开源视觉语言模型(VLM)。CogVLM-17B 拥有 100 亿视觉参数和 70 亿语言参数,在 10 个经典跨模态基准测试上取得了 SOTA 性能,包括 NoCaps、Flicker30k captioning、RefCOCO、RefCOCO+、RefCOCOg、Visual7W、GQA、ScienceQA、VizWiz VQA 和 TDIUC,而在 VQAv2、OKVQA、TextVQA、COCO captioning 等方面则排名第二,超越或与 PaLI-X 55B 持平。您可以通过线上 [demo](http://36.103.203.44:7861/) 体验 CogVLM 多模态对话。
|
9 |
+
|
10 |
+
**CogVLM** is a powerful **open-source visual language model** (**VLM**). CogVLM-17B has 10 billion vision parameters and 7 billion language parameters. CogVLM-17B achieves state-of-the-art performance on 10 classic cross-modal benchmarks, including NoCaps, Flicker30k captioning, RefCOCO, RefCOCO+, RefCOCOg, Visual7W, GQA, ScienceQA, VizWiz VQA and TDIUC, and rank the 2nd on VQAv2, OKVQA, TextVQA, COCO captioning, etc., **surpassing or matching PaLI-X 55B**. CogVLM can also [chat with you](http://36.103.203.44:7861/) about images.
|
11 |
+
|
12 |
+
<div align="center">
|
13 |
+
<img src="https://github.com/THUDM/CogVLM/raw/main/assets/metrics-min.png" alt="img" style="zoom: 50%;" />
|
14 |
+
</div>
|
15 |
+
|
16 |
+
# 快速开始(Qiuckstart)
|
17 |
+
|
18 |
+
硬件需求(hardware requirement)
|
19 |
+
|
20 |
+
需要近 40GB GPU 显存用于模型推理。如果没有一整块GPU显存超过40GB,则需要使用accelerate的将模型切分到多个有较小显存的GPU设备上。
|
21 |
+
|
22 |
+
40GB VRAM for inference. If there is no single GPU with more than 40GB of VRAM, you will need to use the "accelerate" library to dispatch the model into multiple GPUs with smaller VRAM.
|
23 |
+
|
24 |
+
安装依赖(dependencies)
|
25 |
+
|
26 |
+
```base
|
27 |
+
pip install torch==2.1.0 transformers==4.35.0 accelerate==0.24.1 sentencepiece==0.1.99 einops==0.7.0 xformers==0.0.22.post7 triton==2.1.0
|
28 |
+
```
|
29 |
+
|
30 |
+
代码示例(example)
|
31 |
+
|
32 |
+
```python
|
33 |
+
import torch
|
34 |
+
import requests
|
35 |
+
from PIL import Image
|
36 |
+
from transformers import AutoModelForCausalLM, LlamaTokenizer
|
37 |
+
|
38 |
+
tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
|
39 |
+
model = AutoModelForCausalLM.from_pretrained(
|
40 |
+
'THUDM/cogvlm-chat-hf',
|
41 |
+
torch_dtype=torch.bfloat16,
|
42 |
+
low_cpu_mem_usage=True,
|
43 |
+
trust_remote_code=True
|
44 |
+
).to('cuda').eval()
|
45 |
+
|
46 |
+
|
47 |
+
# chat example
|
48 |
+
query = 'Describe this image'
|
49 |
+
image = Image.open(requests.get('https://github.com/THUDM/CogVLM/blob/main/examples/1.png?raw=true', stream=True).raw).convert('RGB')
|
50 |
+
inputs = model.build_conversation_input_ids(tokenizer, query=query, history=[], images=[image]) # chat mode
|
51 |
+
inputs = {
|
52 |
+
'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'),
|
53 |
+
'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'),
|
54 |
+
'attention_mask': inputs['attention_mask'].unsqueeze(0).to('cuda'),
|
55 |
+
'images': [[inputs['images'][0].to('cuda').to(torch.bfloat16)]],
|
56 |
+
}
|
57 |
+
gen_kwargs = {"max_length": 2048, "do_sample": False}
|
58 |
+
|
59 |
+
with torch.no_grad():
|
60 |
+
outputs = model.generate(**inputs, **gen_kwargs)
|
61 |
+
outputs = outputs[:, inputs['input_ids'].shape[1]:]
|
62 |
+
print(tokenizer.decode(outputs[0]))
|
63 |
+
|
64 |
+
# This image captures a moment from a basketball game. Two players are prominently featured: one wearing a yellow jersey with the number
|
65 |
+
# 24 and the word 'Lakers' written on it, and the other wearing a navy blue jersey with the word 'Washington' and the number 34. The player
|
66 |
+
# in yellow is holding a basketball and appears to be dribbling it, while the player in navy blue is reaching out with his arm, possibly
|
67 |
+
# trying to block or defend. The background shows a filled stadium with spectators, indicating that this is a professional game.</s>
|
68 |
+
|
69 |
+
|
70 |
+
|
71 |
+
# vqa example
|
72 |
+
query = 'How many houses are there in this cartoon?'
|
73 |
+
image = Image.open(requests.get('https://github.com/THUDM/CogVLM/blob/main/examples/3.jpg?raw=true', stream=True).raw).convert('RGB')
|
74 |
+
inputs = model.build_conversation_input_ids(tokenizer, query=query, history=[], images=[image], template_version='vqa') # vqa mode
|
75 |
+
inputs = {
|
76 |
+
'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'),
|
77 |
+
'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'),
|
78 |
+
'attention_mask': inputs['attention_mask'].unsqueeze(0).to('cuda'),
|
79 |
+
'images': [[inputs['images'][0].to('cuda').to(torch.bfloat16)]],
|
80 |
+
}
|
81 |
+
gen_kwargs = {"max_length": 2048, "do_sample": False}
|
82 |
+
|
83 |
+
with torch.no_grad():
|
84 |
+
outputs = model.generate(**inputs, **gen_kwargs)
|
85 |
+
outputs = outputs[:, inputs['input_ids'].shape[1]:]
|
86 |
+
print(tokenizer.decode(outputs[0]))
|
87 |
+
|
88 |
+
# 4</s>
|
89 |
+
```
|
90 |
+
|
91 |
+
当单卡显存不足时,可以将模型切分到多个小显存GPU上。以下是个当你有两张24GB的GPU,16GBCPU内存的例子。
|
92 |
+
你可以将`infer_auto_device_map`的参数改成你的配置。注意这里将GPU显存少写了一点,这是为推理时中间状态预留出一部分显存。
|
93 |
+
|
94 |
+
dispatch the model into multiple GPUs with smaller VRAM. This is an example for you have two 24GB GPU and 16GB CPU memory.
|
95 |
+
you can change the arguments of `infer_auto_device_map` with your own setting.
|
96 |
+
|
97 |
+
```python
|
98 |
+
import torch
|
99 |
+
import requests
|
100 |
+
from PIL import Image
|
101 |
+
from transformers import AutoModelForCausalLM, LlamaTokenizer
|
102 |
+
from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_and_dispatch
|
103 |
+
|
104 |
+
tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
|
105 |
+
with init_empty_weights():
|
106 |
+
model = AutoModelForCausalLM.from_pretrained(
|
107 |
+
'THUDM/cogvlm-chat-hf',
|
108 |
+
torch_dtype=torch.bfloat16,
|
109 |
+
low_cpu_mem_usage=True,
|
110 |
+
trust_remote_code=True,
|
111 |
+
)
|
112 |
+
device_map = infer_auto_device_map(model, max_memory={0:'20GiB',1:'20GiB','cpu':'16GiB'}, no_split_module_classes='CogVLMDecoderLayer')
|
113 |
+
model = load_checkpoint_and_dispatch(
|
114 |
+
model,
|
115 |
+
'local/path/to/hf/version/chat/model', # typical, '~/.cache/huggingface/hub/models--THUDM--cogvlm-chat-hf/snapshots/balabala'
|
116 |
+
device_map=device_map,
|
117 |
+
)
|
118 |
+
model = model.eval()
|
119 |
+
|
120 |
+
# check device for weights if u want to
|
121 |
+
for n, p in model.named_parameters():
|
122 |
+
print(f"{n}: {p.device}")
|
123 |
+
|
124 |
+
# chat example
|
125 |
+
query = 'Describe this image'
|
126 |
+
image = Image.open(requests.get('https://github.com/THUDM/CogVLM/blob/main/examples/1.png?raw=true', stream=True).raw).convert('RGB')
|
127 |
+
inputs = model.build_conversation_input_ids(tokenizer, query=query, history=[], images=[image]) # chat mode
|
128 |
+
inputs = {
|
129 |
+
'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'),
|
130 |
+
'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'),
|
131 |
+
'attention_mask': inputs['attention_mask'].unsqueeze(0).to('cuda'),
|
132 |
+
'images': [[inputs['images'][0].to('cuda').to(torch.bfloat16)]],
|
133 |
+
}
|
134 |
+
gen_kwargs = {"max_length": 2048, "do_sample": False}
|
135 |
+
|
136 |
+
with torch.no_grad():
|
137 |
+
outputs = model.generate(**inputs, **gen_kwargs)
|
138 |
+
outputs = outputs[:, inputs['input_ids'].shape[1]:]
|
139 |
+
print(tokenizer.decode(outputs[0]))
|
140 |
+
```
|
141 |
+
|
142 |
+
|
143 |
+
|
144 |
+
# 方法(Method)
|
145 |
+
|
146 |
+
CogVLM 模型包括四个基本组件:视觉变换器(ViT)编码器、MLP适配器、预训练的大型语言模型(GPT)和一个**视觉专家模块**。更多细节请参见[Paper](https://github.com/THUDM/CogVLM/blob/main/assets/cogvlm-paper.pdf)。
|
147 |
+
|
148 |
+
CogVLM model comprises four fundamental components: a vision transformer (ViT) encoder, an MLP adapter, a pretrained large language model (GPT), and a **visual expert module**. See [Paper](https://github.com/THUDM/CogVLM/blob/main/assets/cogvlm-paper.pdf) for more details.
|
149 |
+
|
150 |
+
<div align="center">
|
151 |
+
<img src="https://github.com/THUDM/CogVLM/raw/main/assets/method-min.png" style="zoom:50%;" />
|
152 |
+
</div>
|
153 |
+
|
154 |
+
# 许可(License)
|
155 |
+
|
156 |
+
此存储库中的代码是根据 [Apache-2.0 许可](https://github.com/THUDM/CogVLM/raw/main/LICENSE) 开放源码,而使用 CogVLM 模型权重必须遵循 [模型许可](https://github.com/THUDM/CogVLM/raw/main/MODEL_LICENSE)。
|
157 |
+
|
158 |
+
The code in this repository is open source under the [Apache-2.0 license](https://github.com/THUDM/CogVLM/raw/main/LICENSE), while the use of the CogVLM model weights must comply with the [Model License](https://github.com/THUDM/CogVLM/raw/main/MODEL_LICENSE).
|
159 |
+
|
160 |
+
|
161 |
+
|
162 |
+
# 引用(Citation)
|
163 |
+
|
164 |
+
If you find our work helpful, please consider citing the following papers
|
165 |
+
```
|
166 |
+
@article{wang2023cogvlm,
|
167 |
+
title={CogVLM: Visual Expert for Pretrained Language Models},
|
168 |
+
author={Weihan Wang and Qingsong Lv and Wenmeng Yu and Wenyi Hong and Ji Qi and Yan Wang and Junhui Ji and Zhuoyi Yang and Lei Zhao and Xixuan Song and Jiazheng Xu and Bin Xu and Juanzi Li and Yuxiao Dong and Ming Ding and Jie Tang},
|
169 |
+
year={2023},
|
170 |
+
eprint={2311.03079},
|
171 |
+
archivePrefix={arXiv},
|
172 |
+
primaryClass={cs.CV}
|
173 |
+
}
|
174 |
+
```
|
config.json
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "cogvlm-chat-v1.1",
|
3 |
+
"architectures": [
|
4 |
+
"CogVLMForCausalLM"
|
5 |
+
],
|
6 |
+
"auto_map": {
|
7 |
+
"AutoConfig": "configuration_cogvlm.CogVLMConfig",
|
8 |
+
"AutoModelForCausalLM": "modeling_cogvlm.CogVLMForCausalLM"
|
9 |
+
},
|
10 |
+
"bos_token_id": 1,
|
11 |
+
"eos_token_id": 2,
|
12 |
+
"hidden_act": "silu",
|
13 |
+
"hidden_size": 4096,
|
14 |
+
"initializer_range": 0.02,
|
15 |
+
"intermediate_size": 11008,
|
16 |
+
"max_position_embeddings": 2048,
|
17 |
+
"num_attention_heads": 32,
|
18 |
+
"num_hidden_layers": 32,
|
19 |
+
"pad_token_id": 0,
|
20 |
+
"rms_norm_eps": 1e-05,
|
21 |
+
"template_version": "chat",
|
22 |
+
"tie_word_embeddings": false,
|
23 |
+
"torch_dtype": "bfloat16",
|
24 |
+
"transformers_version": "4.35.0",
|
25 |
+
"use_cache": true,
|
26 |
+
"vision_config": {
|
27 |
+
"dropout_prob": 0.0,
|
28 |
+
"hidden_act": "gelu",
|
29 |
+
"hidden_size": 1792,
|
30 |
+
"image_size": 490,
|
31 |
+
"in_channels": 3,
|
32 |
+
"intermediate_size": 15360,
|
33 |
+
"layer_norm_eps": 1e-06,
|
34 |
+
"num_heads": 16,
|
35 |
+
"num_hidden_layers": 63,
|
36 |
+
"num_positions": 1226,
|
37 |
+
"patch_size": 14
|
38 |
+
},
|
39 |
+
"vocab_size": 32000
|
40 |
+
}
|
configuration_cogvlm.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Literal
|
2 |
+
from transformers import PretrainedConfig
|
3 |
+
|
4 |
+
|
5 |
+
class CogVLMConfig(PretrainedConfig):
|
6 |
+
_auto_class = "AutoConfig"
|
7 |
+
|
8 |
+
def __init__(
|
9 |
+
self,
|
10 |
+
vocab_size=32000,
|
11 |
+
hidden_size=4096,
|
12 |
+
intermediate_size=11008,
|
13 |
+
num_hidden_layers=32,
|
14 |
+
num_attention_heads=32,
|
15 |
+
hidden_act='silu',
|
16 |
+
max_position_embeddings=2048,
|
17 |
+
initializer_range=0.02,
|
18 |
+
rms_norm_eps=1e-06,
|
19 |
+
template_version: Literal["base", "chat"] = "chat",
|
20 |
+
|
21 |
+
pad_token_id=0,
|
22 |
+
bos_token_id=1,
|
23 |
+
eos_token_id=2,
|
24 |
+
tie_word_embeddings=False,
|
25 |
+
use_cache=True,
|
26 |
+
**kwargs,
|
27 |
+
):
|
28 |
+
self.hidden_size = hidden_size
|
29 |
+
self.intermediate_size = intermediate_size
|
30 |
+
self.num_attention_heads = num_attention_heads
|
31 |
+
self.max_position_embeddings = max_position_embeddings
|
32 |
+
self.rms_norm_eps = rms_norm_eps
|
33 |
+
self.initializer_range = initializer_range
|
34 |
+
self.vocab_size = vocab_size
|
35 |
+
self.num_hidden_layers = num_hidden_layers
|
36 |
+
self.hidden_act = hidden_act
|
37 |
+
self.template_version = template_version
|
38 |
+
self.use_cache = use_cache
|
39 |
+
super().__init__(
|
40 |
+
pad_token_id=pad_token_id,
|
41 |
+
bos_token_id=bos_token_id,
|
42 |
+
eos_token_id=eos_token_id,
|
43 |
+
tie_word_embeddings=tie_word_embeddings,
|
44 |
+
**kwargs,
|
45 |
+
)
|
generation_config.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"bos_token_id": 1,
|
4 |
+
"eos_token_id": 2,
|
5 |
+
"pad_token_id": 0,
|
6 |
+
"transformers_version": "4.35.0"
|
7 |
+
}
|
model-00001-of-00008.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e29f6ec471ca55789ab14947b527729b9c30313ceb1e7726590b85f9f6406cca
|
3 |
+
size 4938885184
|
model-00002-of-00008.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e82356882701b1a778408f31e676d17c2aff799c543e8596ed74bc805b4a1213
|
3 |
+
size 4947290688
|
model-00003-of-00008.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:04096f84f42798d0c89319ff8254995a2a3512c16ec88dfd078ce421867d92ec
|
3 |
+
size 4947307592
|
model-00004-of-00008.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2b42af0bb16647959b3e55def4b3c66ab8c3a25fd948a5245c81d070f2b4313d
|
3 |
+
size 4991331080
|
model-00005-of-00008.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:38c07825790e055dd169376479994a58a4f59775ba7cf31d5ca25d8a465e7b0c
|
3 |
+
size 4991331088
|
model-00006-of-00008.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d01880ca5677e69a5f8632f9dda62814f0c549b5a40d4f7e136065e5d64c1a7d
|
3 |
+
size 4970162920
|
model-00007-of-00008.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e70b0e10d2ac8800e69e514b6a9b04ac28cd7db43985ce62daa4e0e639b4e5ba
|
3 |
+
size 4960543792
|
model-00008-of-00008.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a756381ef65b92af7f1fb97da3d59cb04586080982de86d76805299898223294
|
3 |
+
size 532677104
|
model.safetensors.index.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
modeling_cogvlm.py
ADDED
@@ -0,0 +1,783 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""largely copy from llama and adapt for cogvlm"""
|
2 |
+
import warnings
|
3 |
+
from typing import TYPE_CHECKING, Optional, Tuple, List, Union, Literal, Dict, Any
|
4 |
+
|
5 |
+
import math
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn import CrossEntropyLoss
|
9 |
+
from torchvision import transforms
|
10 |
+
from einops import rearrange
|
11 |
+
|
12 |
+
from transformers import PreTrainedModel, PreTrainedTokenizer
|
13 |
+
from transformers.utils.logging import get_logger
|
14 |
+
from transformers.activations import ACT2FN
|
15 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
16 |
+
|
17 |
+
from .configuration_cogvlm import CogVLMConfig
|
18 |
+
from .util import FastRotaryEmbedding
|
19 |
+
from .visual import EVA2CLIPModel
|
20 |
+
|
21 |
+
if TYPE_CHECKING:
|
22 |
+
from transformers.utils import ModelOutput
|
23 |
+
|
24 |
+
logger = get_logger(__name__)
|
25 |
+
|
26 |
+
LANGUAGE_TOKEN_TYPE = 0
|
27 |
+
VISION_TOKEN_TYPE = 1
|
28 |
+
|
29 |
+
|
30 |
+
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
31 |
+
def _make_causal_mask(
|
32 |
+
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
|
33 |
+
):
|
34 |
+
"""
|
35 |
+
Make causal mask used for bi-directional self-attention.
|
36 |
+
"""
|
37 |
+
bsz, tgt_len = input_ids_shape
|
38 |
+
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
|
39 |
+
mask_cond = torch.arange(mask.size(-1), device=device)
|
40 |
+
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
41 |
+
mask = mask.to(dtype)
|
42 |
+
|
43 |
+
if past_key_values_length > 0:
|
44 |
+
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
|
45 |
+
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
46 |
+
|
47 |
+
|
48 |
+
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
49 |
+
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
50 |
+
"""
|
51 |
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
52 |
+
"""
|
53 |
+
bsz, src_len = mask.size()
|
54 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
55 |
+
|
56 |
+
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
57 |
+
|
58 |
+
inverted_mask = 1.0 - expanded_mask
|
59 |
+
|
60 |
+
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
61 |
+
|
62 |
+
|
63 |
+
class RMSNorm(nn.Module):
|
64 |
+
def __init__(self, hidden_size, eps=1e-6):
|
65 |
+
super().__init__()
|
66 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
67 |
+
self.variance_epsilon = eps
|
68 |
+
|
69 |
+
def forward(self, hidden_states):
|
70 |
+
input_dtype = hidden_states.dtype
|
71 |
+
hidden_states = hidden_states.to(torch.float32)
|
72 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
73 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
74 |
+
return (self.weight * hidden_states).to(input_dtype)
|
75 |
+
|
76 |
+
|
77 |
+
class MLP(nn.Module):
|
78 |
+
def __init__(self, config):
|
79 |
+
super().__init__()
|
80 |
+
self.hidden_size = config.hidden_size
|
81 |
+
self.intermediate_size = config.intermediate_size
|
82 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
83 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
84 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
85 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
86 |
+
|
87 |
+
def forward(self, x):
|
88 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
89 |
+
return down_proj
|
90 |
+
|
91 |
+
|
92 |
+
def get_expert_mask(token_type_ids: "torch.LongTensor(B, L)") -> "[torch.BoolTensor(B, L), torch.BoolTensor(B, L)]":
|
93 |
+
vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool)
|
94 |
+
vision_token_mask[:, :-1] = (token_type_ids[:, :-1] == VISION_TOKEN_TYPE) & (token_type_ids[:, 1:] == VISION_TOKEN_TYPE)
|
95 |
+
language_token_mask = ~vision_token_mask
|
96 |
+
return vision_token_mask, language_token_mask
|
97 |
+
|
98 |
+
|
99 |
+
class VisionExpertMLP(nn.Module):
|
100 |
+
def __init__(self, config):
|
101 |
+
super().__init__()
|
102 |
+
self.language_mlp = MLP(config)
|
103 |
+
self.vision_mlp = MLP(config)
|
104 |
+
|
105 |
+
def forward(self, hidden_states: "torch.Tensor(B, L, D)", token_type_ids: "torch.LongTensor(B, L)"):
|
106 |
+
output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device)
|
107 |
+
vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
|
108 |
+
output[vision_token_mask] = self.vision_mlp(hidden_states[vision_token_mask])
|
109 |
+
output[language_token_mask] = self.language_mlp(hidden_states[language_token_mask])
|
110 |
+
return output
|
111 |
+
|
112 |
+
|
113 |
+
def attention_fn(
|
114 |
+
query_layer: "torch.tensor(B, H, L, HD)",
|
115 |
+
key_layer: "torch.tensor(B, H, L, HD)",
|
116 |
+
value_layer: "torch.tensor(B, H, L, HD)",
|
117 |
+
attention_mask: "torch.tensor(B, H, L, HD)",
|
118 |
+
*,
|
119 |
+
scaling_attention_score: bool = True,
|
120 |
+
attention_dropout: nn.Module = None
|
121 |
+
):
|
122 |
+
attention_mask_bool = (attention_mask == 0)
|
123 |
+
is_low_triangle = (attention_mask_bool == torch.ones_like(attention_mask_bool, dtype=torch.float).tril()).all()
|
124 |
+
is_full = (attention_mask_bool > 0).all()
|
125 |
+
if not (int(torch.__version__.split('.')[0]) >= 2):
|
126 |
+
warnings.warn("It's recommended to use torch2.0 or higher.")
|
127 |
+
if int(torch.__version__.split('.')[0]) >= 2 and scaling_attention_score and (is_full or is_low_triangle):
|
128 |
+
dropout_p = 0. if attention_dropout is None or not attention_dropout.training else attention_dropout.p
|
129 |
+
return torch.nn.functional.scaled_dot_product_attention(
|
130 |
+
query_layer, key_layer, value_layer,
|
131 |
+
attn_mask=None,
|
132 |
+
dropout_p=dropout_p,
|
133 |
+
is_causal=not is_full
|
134 |
+
)
|
135 |
+
else:
|
136 |
+
if scaling_attention_score:
|
137 |
+
query_layer = query_layer / math.sqrt(query_layer.shape[-1])
|
138 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
139 |
+
attention_scores = attention_scores + attention_mask
|
140 |
+
attention_scores = nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).to(query_layer.dtype)
|
141 |
+
if attention_dropout is not None:
|
142 |
+
attention_scores = attention_dropout(attention_scores)
|
143 |
+
context_layer = torch.matmul(attention_scores, value_layer)
|
144 |
+
return context_layer
|
145 |
+
|
146 |
+
|
147 |
+
class VisionExpertAttention(nn.Module):
|
148 |
+
def __init__(self, config):
|
149 |
+
super().__init__()
|
150 |
+
self.config = config
|
151 |
+
self.hidden_size = config.hidden_size
|
152 |
+
self.num_heads = config.num_attention_heads
|
153 |
+
self.head_dim = self.hidden_size // self.num_heads
|
154 |
+
self.max_position_embeddings = config.max_position_embeddings
|
155 |
+
|
156 |
+
# self.rotary_emb = RotaryEmbedding(self.hidden_size // self.num_heads)
|
157 |
+
self.rotary_emb = FastRotaryEmbedding(dim=self.head_dim, pos_idx_in_fp32=False)
|
158 |
+
self.vision_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
|
159 |
+
self.vision_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
|
160 |
+
self.language_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
|
161 |
+
self.language_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
|
162 |
+
|
163 |
+
def _transpose_for_scores(self, tensor):
|
164 |
+
"""Transpose a 3D tensor [B, L, H*HD] into a 4D tensor with size [B H L HD]."""
|
165 |
+
new_tensor_shape = tensor.size()[:-1] + (self.num_heads, self.head_dim)
|
166 |
+
tensor = tensor.view(*new_tensor_shape)
|
167 |
+
return tensor.permute(0, 2, 1, 3)
|
168 |
+
|
169 |
+
def forward(
|
170 |
+
self,
|
171 |
+
hidden_states: torch.Tensor,
|
172 |
+
token_type_ids: torch.LongTensor,
|
173 |
+
position_ids: torch.LongTensor,
|
174 |
+
attention_mask: Optional[torch.Tensor] = None,
|
175 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
176 |
+
output_attentions: bool = False,
|
177 |
+
use_cache: bool = False,
|
178 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
179 |
+
bsz, q_len, _ = hidden_states.size()
|
180 |
+
vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
|
181 |
+
|
182 |
+
shape = list(hidden_states.shape)
|
183 |
+
shape[-1] = shape[-1] * 3
|
184 |
+
mixed_raw_layer = torch.empty(shape, dtype=hidden_states.dtype, device=hidden_states.device)
|
185 |
+
mixed_raw_layer[vision_token_mask] = self.vision_expert_query_key_value(hidden_states[vision_token_mask])
|
186 |
+
mixed_raw_layer[language_token_mask] = self.language_expert_query_key_value(hidden_states[language_token_mask])
|
187 |
+
|
188 |
+
query_states, key_states, value_states = torch.split(mixed_raw_layer, self.hidden_size, dim=-1)
|
189 |
+
query_states = self._transpose_for_scores(query_states) # B, H, L, HD
|
190 |
+
key_states = self._transpose_for_scores(key_states) # B, H, L, HD
|
191 |
+
value_states = self._transpose_for_scores(value_states) # B, H, L, HD
|
192 |
+
|
193 |
+
kv_seq_len = key_states.shape[-2]
|
194 |
+
if past_key_value is not None:
|
195 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
196 |
+
|
197 |
+
query_states, key_states = self.rotary_emb(query_states, key_states, position_ids=position_ids, max_seqlen=position_ids.max() + 1)
|
198 |
+
|
199 |
+
if past_key_value is not None:
|
200 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
201 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
202 |
+
|
203 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
204 |
+
|
205 |
+
context_layer = attention_fn(
|
206 |
+
query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
|
207 |
+
scaling_attention_score=True, attention_dropout=None)
|
208 |
+
if context_layer.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
209 |
+
raise ValueError(
|
210 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
211 |
+
f" {context_layer.size()}"
|
212 |
+
)
|
213 |
+
context_layer = context_layer.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
|
214 |
+
|
215 |
+
attn_output = torch.empty(context_layer.shape, dtype=hidden_states.dtype, device=hidden_states.device)
|
216 |
+
attn_output[vision_token_mask] = self.vision_expert_dense(context_layer[vision_token_mask])
|
217 |
+
attn_output[language_token_mask] = self.language_expert_dense(context_layer[language_token_mask])
|
218 |
+
|
219 |
+
if output_attentions:
|
220 |
+
warnings.warn("output_attentions is not implemented.")
|
221 |
+
|
222 |
+
return attn_output, None, past_key_value
|
223 |
+
|
224 |
+
|
225 |
+
class CogVLMDecoderLayer(nn.Module):
|
226 |
+
def __init__(self, config):
|
227 |
+
super().__init__()
|
228 |
+
self.hidden_size = config.hidden_size
|
229 |
+
self.self_attn = VisionExpertAttention(config=config)
|
230 |
+
self.mlp = VisionExpertMLP(config)
|
231 |
+
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
232 |
+
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
233 |
+
|
234 |
+
def forward(
|
235 |
+
self,
|
236 |
+
hidden_states: torch.Tensor,
|
237 |
+
token_type_ids: torch.LongTensor,
|
238 |
+
position_ids: torch.LongTensor,
|
239 |
+
attention_mask: Optional[torch.Tensor] = None,
|
240 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
241 |
+
output_attentions: Optional[bool] = False,
|
242 |
+
use_cache: Optional[bool] = False,
|
243 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
244 |
+
residual = hidden_states
|
245 |
+
|
246 |
+
hidden_states = self.input_layernorm(hidden_states)
|
247 |
+
|
248 |
+
# Self Attention
|
249 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
250 |
+
hidden_states=hidden_states,
|
251 |
+
token_type_ids=token_type_ids,
|
252 |
+
position_ids=position_ids,
|
253 |
+
attention_mask=attention_mask,
|
254 |
+
past_key_value=past_key_value,
|
255 |
+
output_attentions=output_attentions,
|
256 |
+
use_cache=use_cache,
|
257 |
+
)
|
258 |
+
hidden_states = residual + hidden_states
|
259 |
+
|
260 |
+
# Fully Connected
|
261 |
+
residual = hidden_states
|
262 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
263 |
+
hidden_states = self.mlp(hidden_states, token_type_ids=token_type_ids)
|
264 |
+
hidden_states = residual + hidden_states
|
265 |
+
|
266 |
+
outputs = (hidden_states,)
|
267 |
+
|
268 |
+
if output_attentions:
|
269 |
+
outputs += (self_attn_weights,)
|
270 |
+
|
271 |
+
if use_cache:
|
272 |
+
outputs += (present_key_value,)
|
273 |
+
|
274 |
+
return outputs # type: ignore
|
275 |
+
|
276 |
+
|
277 |
+
class CogVLMPreTrainedModel(PreTrainedModel):
|
278 |
+
config_class = CogVLMConfig
|
279 |
+
base_model_prefix = "model"
|
280 |
+
supports_gradient_checkpointing = False
|
281 |
+
_no_split_modules = ["CogVLMDecoderLayer", "TransformerLayer"]
|
282 |
+
_skip_keys_device_placement = "past_key_values"
|
283 |
+
|
284 |
+
def _init_weights(self, module):
|
285 |
+
std = self.config.initializer_range
|
286 |
+
if isinstance(module, nn.Linear):
|
287 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
288 |
+
if module.bias is not None:
|
289 |
+
module.bias.data.zero_()
|
290 |
+
elif isinstance(module, nn.Embedding):
|
291 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
292 |
+
if module.padding_idx is not None:
|
293 |
+
module.weight.data[module.padding_idx].zero_()
|
294 |
+
|
295 |
+
|
296 |
+
def is_empty(images_list: Optional[List[List[torch.Tensor]]]):
|
297 |
+
if images_list is None or len(images_list) == 0:
|
298 |
+
return True
|
299 |
+
for image_list in images_list:
|
300 |
+
if len(image_list):
|
301 |
+
return False
|
302 |
+
return True
|
303 |
+
|
304 |
+
|
305 |
+
def build_position_ids(x: "torch.BoolTensor(B, L)", attention_mask: Optional["torch.BoolTensor(B, L)"] = None) -> "torch.LongTensor(B, L)":
|
306 |
+
if attention_mask is not None:
|
307 |
+
tmp = x.clone()
|
308 |
+
tmp[~(attention_mask.bool())] = -1
|
309 |
+
else:
|
310 |
+
tmp = x.clone()
|
311 |
+
# image boi eoi token as LANGUAGE_TOKEN_TYPE
|
312 |
+
is_boi_eoi = torch.zeros_like(x, dtype=torch.bool)
|
313 |
+
is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
|
314 |
+
is_boi_eoi[:, 0] |= (tmp[:, 0] == VISION_TOKEN_TYPE)
|
315 |
+
is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE)
|
316 |
+
is_boi_eoi[:, -1] |= (tmp[:, -1] == VISION_TOKEN_TYPE)
|
317 |
+
tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
|
318 |
+
# final position ids
|
319 |
+
y = torch.zeros_like(x, dtype=torch.long)
|
320 |
+
y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | ((tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE))
|
321 |
+
y = y.cumsum(dim=-1)
|
322 |
+
return y
|
323 |
+
|
324 |
+
|
325 |
+
class CogVLMModel(CogVLMPreTrainedModel):
|
326 |
+
def __init__(self, config):
|
327 |
+
super().__init__(config)
|
328 |
+
self.padding_idx = config.pad_token_id
|
329 |
+
self.vocab_size = config.vocab_size
|
330 |
+
|
331 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
332 |
+
self.layers = nn.ModuleList([CogVLMDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
333 |
+
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
334 |
+
|
335 |
+
self.vision = EVA2CLIPModel(config)
|
336 |
+
|
337 |
+
self.gradient_checkpointing = False
|
338 |
+
# Initialize weights and apply final processing
|
339 |
+
self.post_init()
|
340 |
+
|
341 |
+
def encode_images(self, images: List[List[torch.Tensor]]) -> torch.Tensor:
|
342 |
+
images_list, images = images, []
|
343 |
+
|
344 |
+
images = []
|
345 |
+
for image_list in images_list:
|
346 |
+
for image in image_list:
|
347 |
+
images.append(image)
|
348 |
+
|
349 |
+
images = torch.stack(images)
|
350 |
+
images_features = self.vision(images)
|
351 |
+
return images_features
|
352 |
+
|
353 |
+
def forward(
|
354 |
+
self,
|
355 |
+
input_ids: torch.LongTensor = None,
|
356 |
+
images: List[List[torch.Tensor]] = None,
|
357 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
358 |
+
attention_mask: Optional[torch.Tensor] = None,
|
359 |
+
position_ids: Optional[torch.LongTensor] = None,
|
360 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
361 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
362 |
+
use_cache: Optional[bool] = None,
|
363 |
+
output_attentions: Optional[bool] = None,
|
364 |
+
output_hidden_states: Optional[bool] = None,
|
365 |
+
return_dict: Optional[bool] = None,
|
366 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
367 |
+
"""take care of image_encode, token_type_ids, position_ids and (attention_mask = None is fine)"""
|
368 |
+
|
369 |
+
if past_key_values is not None:
|
370 |
+
pass # generate mode with past_key_values. the image features are already mapped
|
371 |
+
else:
|
372 |
+
# not allow for inputs_embeds, because we want to process image feature
|
373 |
+
assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
|
374 |
+
if not is_empty(images): # multi-modality
|
375 |
+
assert token_type_ids is not None, f"multi-modality requires `token_type_ids`!"
|
376 |
+
assert len(input_ids) == len(images), f"{len(input_ids)} {len(images)}"
|
377 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
378 |
+
images_features = self.encode_images(images)
|
379 |
+
images_features = rearrange(images_features, 'b n d -> (b n) d')
|
380 |
+
images_features = images_features.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
|
381 |
+
inputs_embeds = inputs_embeds.index_put([token_type_ids == VISION_TOKEN_TYPE], images_features)
|
382 |
+
else: # single-modality
|
383 |
+
if token_type_ids is None:
|
384 |
+
token_type_ids = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device) * LANGUAGE_TOKEN_TYPE
|
385 |
+
assert not (token_type_ids == VISION_TOKEN_TYPE).any(), f"{(token_type_ids == VISION_TOKEN_TYPE).sum()}"
|
386 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
387 |
+
|
388 |
+
if position_ids is None:
|
389 |
+
position_ids = build_position_ids(token_type_ids, attention_mask)
|
390 |
+
input_ids = None
|
391 |
+
|
392 |
+
return self.llm_forward(
|
393 |
+
input_ids=input_ids,
|
394 |
+
token_type_ids=token_type_ids,
|
395 |
+
attention_mask=attention_mask,
|
396 |
+
position_ids=position_ids,
|
397 |
+
past_key_values=past_key_values,
|
398 |
+
inputs_embeds=inputs_embeds,
|
399 |
+
use_cache=use_cache,
|
400 |
+
output_attentions=output_attentions,
|
401 |
+
output_hidden_states=output_hidden_states,
|
402 |
+
return_dict=return_dict,
|
403 |
+
)
|
404 |
+
|
405 |
+
def llm_forward(
|
406 |
+
self,
|
407 |
+
input_ids: torch.LongTensor = None,
|
408 |
+
token_type_ids: torch.LongTensor = None,
|
409 |
+
attention_mask: Optional[torch.Tensor] = None,
|
410 |
+
position_ids: Optional[torch.LongTensor] = None,
|
411 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
412 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
413 |
+
use_cache: Optional[bool] = None,
|
414 |
+
output_attentions: Optional[bool] = None,
|
415 |
+
output_hidden_states: Optional[bool] = None,
|
416 |
+
return_dict: Optional[bool] = None,
|
417 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
418 |
+
"""largely copy from llama forward and adapt for cogvlm with `token_type_ids`"""
|
419 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
420 |
+
output_hidden_states = (
|
421 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
422 |
+
)
|
423 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
424 |
+
|
425 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
426 |
+
|
427 |
+
# retrieve input_ids and inputs_embeds
|
428 |
+
if input_ids is not None and inputs_embeds is not None:
|
429 |
+
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
430 |
+
elif input_ids is not None:
|
431 |
+
batch_size, seq_length = input_ids.shape
|
432 |
+
elif inputs_embeds is not None:
|
433 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
434 |
+
else:
|
435 |
+
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
436 |
+
|
437 |
+
seq_length_with_past = seq_length
|
438 |
+
past_key_values_length = 0
|
439 |
+
|
440 |
+
if past_key_values is not None:
|
441 |
+
past_key_values_length = past_key_values[0][0].shape[2]
|
442 |
+
seq_length_with_past = seq_length_with_past + past_key_values_length
|
443 |
+
|
444 |
+
if position_ids is None:
|
445 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
446 |
+
position_ids = torch.arange(
|
447 |
+
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
448 |
+
)
|
449 |
+
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
450 |
+
else:
|
451 |
+
position_ids = position_ids.view(-1, seq_length).long()
|
452 |
+
|
453 |
+
if inputs_embeds is None:
|
454 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
455 |
+
# embed positions
|
456 |
+
if attention_mask is None:
|
457 |
+
attention_mask = torch.ones(
|
458 |
+
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
459 |
+
)
|
460 |
+
attention_mask = self._prepare_decoder_attention_mask(
|
461 |
+
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
462 |
+
)
|
463 |
+
|
464 |
+
hidden_states = inputs_embeds
|
465 |
+
|
466 |
+
# decoder layers
|
467 |
+
all_hidden_states = () if output_hidden_states else None
|
468 |
+
all_self_attns = () if output_attentions else None
|
469 |
+
next_decoder_cache = () if use_cache else None
|
470 |
+
|
471 |
+
for idx, decoder_layer in enumerate(self.layers):
|
472 |
+
if output_hidden_states:
|
473 |
+
all_hidden_states += (hidden_states,)
|
474 |
+
|
475 |
+
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
476 |
+
layer_outputs = decoder_layer(
|
477 |
+
hidden_states,
|
478 |
+
token_type_ids=token_type_ids,
|
479 |
+
attention_mask=attention_mask,
|
480 |
+
position_ids=position_ids,
|
481 |
+
past_key_value=past_key_value,
|
482 |
+
output_attentions=output_attentions,
|
483 |
+
use_cache=use_cache,
|
484 |
+
)
|
485 |
+
hidden_states = layer_outputs[0]
|
486 |
+
|
487 |
+
if use_cache:
|
488 |
+
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
489 |
+
|
490 |
+
if output_attentions:
|
491 |
+
all_self_attns += (layer_outputs[1],)
|
492 |
+
|
493 |
+
hidden_states = self.norm(hidden_states)
|
494 |
+
|
495 |
+
# add hidden states from the last decoder layer
|
496 |
+
if output_hidden_states:
|
497 |
+
all_hidden_states += (hidden_states,)
|
498 |
+
|
499 |
+
next_cache = next_decoder_cache if use_cache else None
|
500 |
+
if not return_dict:
|
501 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
502 |
+
return BaseModelOutputWithPast(
|
503 |
+
last_hidden_state=hidden_states,
|
504 |
+
past_key_values=next_cache,
|
505 |
+
hidden_states=all_hidden_states,
|
506 |
+
attentions=all_self_attns,
|
507 |
+
)
|
508 |
+
|
509 |
+
def get_input_embeddings(self):
|
510 |
+
return self.embed_tokens
|
511 |
+
|
512 |
+
def set_input_embeddings(self, value):
|
513 |
+
self.embed_tokens = value
|
514 |
+
|
515 |
+
# noinspection PyMethodMayBeStatic
|
516 |
+
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
|
517 |
+
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
|
518 |
+
# create causal mask
|
519 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
520 |
+
combined_attention_mask = None
|
521 |
+
if input_shape[-1] > 1:
|
522 |
+
combined_attention_mask = _make_causal_mask(
|
523 |
+
input_shape,
|
524 |
+
inputs_embeds.dtype,
|
525 |
+
device=inputs_embeds.device,
|
526 |
+
past_key_values_length=past_key_values_length,
|
527 |
+
)
|
528 |
+
|
529 |
+
if attention_mask is not None:
|
530 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
531 |
+
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
|
532 |
+
inputs_embeds.device
|
533 |
+
)
|
534 |
+
combined_attention_mask = (
|
535 |
+
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
|
536 |
+
)
|
537 |
+
|
538 |
+
return combined_attention_mask
|
539 |
+
|
540 |
+
|
541 |
+
def _history_to_prompt(signal_type, history, query):
|
542 |
+
if signal_type == 'base':
|
543 |
+
return query
|
544 |
+
elif signal_type == 'vqa':
|
545 |
+
answer_format = 'Short answer:'
|
546 |
+
elif signal_type == 'chat':
|
547 |
+
answer_format = 'Answer:'
|
548 |
+
else:
|
549 |
+
assert False, f"Unknown signal type {signal_type}"
|
550 |
+
|
551 |
+
prompt = ''
|
552 |
+
for i, (old_query, response) in enumerate(history):
|
553 |
+
prompt += 'Question: ' + old_query + " {} ".format(answer_format) + response + "\n"
|
554 |
+
prompt += 'Question: {} {}'.format(query, answer_format)
|
555 |
+
return prompt
|
556 |
+
|
557 |
+
|
558 |
+
class CogVLMForCausalLM(CogVLMPreTrainedModel):
|
559 |
+
_auto_class = "AutoModelForCausalLM"
|
560 |
+
|
561 |
+
def __init__(self, config):
|
562 |
+
super().__init__(config)
|
563 |
+
self.model = CogVLMModel(config)
|
564 |
+
self.vocab_size = config.vocab_size
|
565 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
566 |
+
|
567 |
+
# Initialize weights and apply final processing
|
568 |
+
self.post_init()
|
569 |
+
|
570 |
+
def get_input_embeddings(self):
|
571 |
+
return self.model.embed_tokens
|
572 |
+
|
573 |
+
def set_input_embeddings(self, value):
|
574 |
+
self.model.embed_tokens = value
|
575 |
+
|
576 |
+
def get_output_embeddings(self):
|
577 |
+
return self.lm_head
|
578 |
+
|
579 |
+
def set_output_embeddings(self, new_embeddings):
|
580 |
+
self.lm_head = new_embeddings
|
581 |
+
|
582 |
+
def set_decoder(self, decoder):
|
583 |
+
self.model = decoder
|
584 |
+
|
585 |
+
def get_decoder(self):
|
586 |
+
return self.model
|
587 |
+
|
588 |
+
def forward(
|
589 |
+
self,
|
590 |
+
input_ids: torch.LongTensor = None,
|
591 |
+
images: List[List[torch.Tensor]] = None,
|
592 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
593 |
+
attention_mask: Optional[torch.Tensor] = None,
|
594 |
+
position_ids: Optional[torch.LongTensor] = None,
|
595 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
596 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
597 |
+
use_cache: Optional[bool] = None,
|
598 |
+
output_attentions: Optional[bool] = None,
|
599 |
+
output_hidden_states: Optional[bool] = None,
|
600 |
+
return_dict: Optional[bool] = None,
|
601 |
+
labels: Optional[torch.LongTensor] = None,
|
602 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
603 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
604 |
+
output_hidden_states = (
|
605 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
606 |
+
)
|
607 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
608 |
+
|
609 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
610 |
+
outputs = self.model(
|
611 |
+
input_ids=input_ids,
|
612 |
+
images=images,
|
613 |
+
token_type_ids=token_type_ids,
|
614 |
+
attention_mask=attention_mask,
|
615 |
+
position_ids=position_ids,
|
616 |
+
past_key_values=past_key_values,
|
617 |
+
inputs_embeds=inputs_embeds,
|
618 |
+
use_cache=use_cache,
|
619 |
+
output_attentions=output_attentions,
|
620 |
+
output_hidden_states=output_hidden_states,
|
621 |
+
return_dict=return_dict,
|
622 |
+
)
|
623 |
+
|
624 |
+
hidden_states = outputs[0]
|
625 |
+
logits = self.lm_head(hidden_states)
|
626 |
+
logits = logits.float()
|
627 |
+
|
628 |
+
loss = None
|
629 |
+
if labels is not None:
|
630 |
+
# Shift so that tokens < n predict n
|
631 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
632 |
+
shift_labels = labels[..., 1:].contiguous()
|
633 |
+
# Flatten the tokens
|
634 |
+
loss_fct = CrossEntropyLoss()
|
635 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
636 |
+
shift_labels = shift_labels.view(-1)
|
637 |
+
# Enable model parallelism
|
638 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
639 |
+
loss = loss_fct(shift_logits, shift_labels)
|
640 |
+
|
641 |
+
if not return_dict:
|
642 |
+
output = (logits,) + outputs[1:]
|
643 |
+
return (loss,) + output if loss is not None else output
|
644 |
+
|
645 |
+
return CausalLMOutputWithPast(
|
646 |
+
loss=loss,
|
647 |
+
logits=logits,
|
648 |
+
past_key_values=outputs.past_key_values,
|
649 |
+
hidden_states=outputs.hidden_states,
|
650 |
+
attentions=outputs.attentions,
|
651 |
+
)
|
652 |
+
|
653 |
+
def _prepare_attention_mask_for_generation(
|
654 |
+
self,
|
655 |
+
inputs: torch.Tensor,
|
656 |
+
pad_token_id: Optional[int],
|
657 |
+
eos_token_id: Optional[Union[int, List[int]]],
|
658 |
+
) -> torch.LongTensor:
|
659 |
+
return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) # type: ignore
|
660 |
+
|
661 |
+
def prepare_inputs_for_generation(
|
662 |
+
self, input_ids, token_type_ids, images=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
663 |
+
):
|
664 |
+
# build position_ids if needed
|
665 |
+
position_ids = kwargs.get("position_ids", None)
|
666 |
+
if position_ids is None:
|
667 |
+
position_ids = build_position_ids(token_type_ids, attention_mask)
|
668 |
+
|
669 |
+
if past_key_values:
|
670 |
+
input_ids = input_ids[:, -1:]
|
671 |
+
token_type_ids = token_type_ids[:, -1:]
|
672 |
+
position_ids = position_ids[:, -1:]
|
673 |
+
|
674 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
675 |
+
if inputs_embeds is not None and past_key_values is None:
|
676 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
677 |
+
else:
|
678 |
+
model_inputs = {"input_ids": input_ids}
|
679 |
+
|
680 |
+
model_inputs.update(
|
681 |
+
{
|
682 |
+
"token_type_ids": token_type_ids,
|
683 |
+
"images": images,
|
684 |
+
"position_ids": position_ids,
|
685 |
+
"past_key_values": past_key_values,
|
686 |
+
"use_cache": kwargs.get("use_cache"),
|
687 |
+
"attention_mask": attention_mask,
|
688 |
+
}
|
689 |
+
)
|
690 |
+
return model_inputs
|
691 |
+
|
692 |
+
def _update_model_kwargs_for_generation(
|
693 |
+
self,
|
694 |
+
outputs: "ModelOutput",
|
695 |
+
model_kwargs: Dict[str, Any],
|
696 |
+
is_encoder_decoder: bool = False,
|
697 |
+
standardize_cache_format: bool = False,
|
698 |
+
) -> Dict[str, Any]:
|
699 |
+
# update past_key_values
|
700 |
+
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
|
701 |
+
outputs, standardize_cache_format=standardize_cache_format
|
702 |
+
)
|
703 |
+
if getattr(outputs, "state", None) is not None:
|
704 |
+
model_kwargs["state"] = outputs.state
|
705 |
+
|
706 |
+
# update token_type_ids with last value
|
707 |
+
if "token_type_ids" in model_kwargs:
|
708 |
+
token_type_ids = model_kwargs["token_type_ids"]
|
709 |
+
new_token_type_ids = torch.ones(size=(token_type_ids.shape[0], 1), dtype=token_type_ids.dtype, device=token_type_ids.device) * LANGUAGE_TOKEN_TYPE
|
710 |
+
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, new_token_type_ids], dim=-1)
|
711 |
+
|
712 |
+
if not is_encoder_decoder:
|
713 |
+
# update attention mask
|
714 |
+
if "attention_mask" in model_kwargs:
|
715 |
+
attention_mask = model_kwargs["attention_mask"]
|
716 |
+
model_kwargs["attention_mask"] = torch.cat(
|
717 |
+
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
|
718 |
+
)
|
719 |
+
else:
|
720 |
+
# update decoder attention mask
|
721 |
+
if "decoder_attention_mask" in model_kwargs:
|
722 |
+
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
|
723 |
+
model_kwargs["decoder_attention_mask"] = torch.cat(
|
724 |
+
[decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
|
725 |
+
dim=-1,
|
726 |
+
)
|
727 |
+
|
728 |
+
return model_kwargs
|
729 |
+
|
730 |
+
def _reorder_cache(self, past_key_values, beam_idx):
|
731 |
+
reordered_past = ()
|
732 |
+
for layer_past in past_key_values:
|
733 |
+
reordered_past += (
|
734 |
+
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
735 |
+
)
|
736 |
+
return reordered_past
|
737 |
+
|
738 |
+
def build_conversation_input_ids(
|
739 |
+
self,
|
740 |
+
tokenizer: "PreTrainedTokenizer",
|
741 |
+
*,
|
742 |
+
query: str,
|
743 |
+
history: Optional[List[Tuple[str, str]]] = None,
|
744 |
+
images: Optional[List["PIL.Image"]] = None,
|
745 |
+
template_version: Optional[Literal["base", "chat", "vqa"]] = None,
|
746 |
+
):
|
747 |
+
image_size: int = self.config.vision_config['image_size']
|
748 |
+
patch_size: int = self.config.vision_config['patch_size']
|
749 |
+
template_version = template_version or self.config.template_version
|
750 |
+
assert images is None or len(images) <= 1, f"not support multi images by now."
|
751 |
+
history = history or []
|
752 |
+
text = _history_to_prompt(template_version, history, query)
|
753 |
+
|
754 |
+
input_ids = [tokenizer.bos_token_id]
|
755 |
+
token_type_ids = [LANGUAGE_TOKEN_TYPE]
|
756 |
+
if images is not None and len(images) == 1:
|
757 |
+
# vision
|
758 |
+
transform = transforms.Compose(
|
759 |
+
[
|
760 |
+
transforms.Resize(
|
761 |
+
(image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC
|
762 |
+
),
|
763 |
+
transforms.ToTensor(),
|
764 |
+
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
765 |
+
]
|
766 |
+
)
|
767 |
+
images = [transform(images[0])]
|
768 |
+
# language
|
769 |
+
vision_token_num = (image_size // patch_size) * (image_size // patch_size) + 2
|
770 |
+
input_ids += [tokenizer.pad_token_id] * vision_token_num
|
771 |
+
token_type_ids += [VISION_TOKEN_TYPE] * vision_token_num
|
772 |
+
text_ids = tokenizer.encode(text, add_special_tokens=False)
|
773 |
+
|
774 |
+
input_ids += text_ids
|
775 |
+
token_type_ids += [LANGUAGE_TOKEN_TYPE] * len(text_ids)
|
776 |
+
attention_mask = [1] * len(input_ids)
|
777 |
+
|
778 |
+
return {
|
779 |
+
'input_ids': torch.tensor(input_ids, dtype=torch.long),
|
780 |
+
'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
|
781 |
+
'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
|
782 |
+
'images': images,
|
783 |
+
}
|
util.py
ADDED
@@ -0,0 +1,483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from einops import rearrange, repeat
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
import triton
|
8 |
+
import triton.language as tl
|
9 |
+
|
10 |
+
|
11 |
+
# @triton.autotune(
|
12 |
+
# configs=[
|
13 |
+
# triton.Config({"BLOCK_M": 2}),
|
14 |
+
# triton.Config({"BLOCK_M": 4}),
|
15 |
+
# triton.Config({"BLOCK_M": 8}),
|
16 |
+
# triton.Config({"BLOCK_M": 16}),
|
17 |
+
# ],
|
18 |
+
# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"],
|
19 |
+
# )
|
20 |
+
@triton.jit
|
21 |
+
def rotary_kernel(
|
22 |
+
OUT, # Pointers to matrices
|
23 |
+
X,
|
24 |
+
COS,
|
25 |
+
SIN,
|
26 |
+
CU_SEQLENS,
|
27 |
+
SEQLEN_OFFSETS, # this could be int or a pointer
|
28 |
+
# Matrix dimensions
|
29 |
+
seqlen,
|
30 |
+
nheads,
|
31 |
+
rotary_dim,
|
32 |
+
seqlen_ro,
|
33 |
+
CACHE_KEY_SEQLEN,
|
34 |
+
# strides
|
35 |
+
stride_out_batch,
|
36 |
+
stride_out_nheads,
|
37 |
+
stride_out_seqlen,
|
38 |
+
stride_out_headdim,
|
39 |
+
stride_x_batch,
|
40 |
+
stride_x_nheads,
|
41 |
+
stride_x_seqlen,
|
42 |
+
stride_x_headdim,
|
43 |
+
# Meta-parameters
|
44 |
+
BLOCK_K: tl.constexpr,
|
45 |
+
IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
|
46 |
+
IS_VARLEN: tl.constexpr,
|
47 |
+
INTERLEAVED: tl.constexpr,
|
48 |
+
CONJUGATE: tl.constexpr,
|
49 |
+
BLOCK_M: tl.constexpr,
|
50 |
+
):
|
51 |
+
pid_m = tl.program_id(axis=0)
|
52 |
+
pid_batch = tl.program_id(axis=1)
|
53 |
+
pid_head = tl.program_id(axis=2)
|
54 |
+
rotary_dim_half = rotary_dim // 2
|
55 |
+
|
56 |
+
if not IS_VARLEN:
|
57 |
+
X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
|
58 |
+
OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
|
59 |
+
COS = COS + pid_batch * seqlen_ro * rotary_dim_half
|
60 |
+
SIN = SIN + pid_batch * seqlen_ro * rotary_dim_half
|
61 |
+
else:
|
62 |
+
start_idx = tl.load(CU_SEQLENS + pid_batch)
|
63 |
+
seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
|
64 |
+
X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads
|
65 |
+
OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads
|
66 |
+
|
67 |
+
if pid_m * BLOCK_M >= seqlen:
|
68 |
+
return
|
69 |
+
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
70 |
+
if not IS_SEQLEN_OFFSETS_TENSOR:
|
71 |
+
rm_cs = rm + SEQLEN_OFFSETS
|
72 |
+
else:
|
73 |
+
rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
|
74 |
+
rk = tl.arange(0, BLOCK_K)
|
75 |
+
rk_half = tl.arange(0, BLOCK_K // 2)
|
76 |
+
|
77 |
+
if not INTERLEAVED:
|
78 |
+
# Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT
|
79 |
+
X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)
|
80 |
+
COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
|
81 |
+
SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
|
82 |
+
cos = tl.load(
|
83 |
+
COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0
|
84 |
+
)
|
85 |
+
sin = tl.load(
|
86 |
+
SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0
|
87 |
+
)
|
88 |
+
x0 = tl.load(
|
89 |
+
X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0
|
90 |
+
)
|
91 |
+
x1 = tl.load(
|
92 |
+
X + rotary_dim_half * stride_x_headdim,
|
93 |
+
mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
|
94 |
+
other=0.0,
|
95 |
+
)
|
96 |
+
if CONJUGATE:
|
97 |
+
sin = -sin
|
98 |
+
o0 = x0 * cos - x1 * sin
|
99 |
+
o1 = x0 * sin + x1 * cos
|
100 |
+
# write back result
|
101 |
+
OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)
|
102 |
+
tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))
|
103 |
+
tl.store(
|
104 |
+
OUT + rotary_dim_half * stride_out_headdim,
|
105 |
+
o1,
|
106 |
+
mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
|
107 |
+
)
|
108 |
+
else:
|
109 |
+
# We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow.
|
110 |
+
# Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].
|
111 |
+
# Loading x0 will be fast but x1 will be slow.
|
112 |
+
# Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...].
|
113 |
+
# Then we do the calculation and use tl.where to pick put the right outputs for the even
|
114 |
+
# and for the odd indices.
|
115 |
+
rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ...
|
116 |
+
rk_repeat = tl.arange(0, BLOCK_K) // 2
|
117 |
+
X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)
|
118 |
+
X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)
|
119 |
+
COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
|
120 |
+
SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
|
121 |
+
cos = tl.load(
|
122 |
+
COS,
|
123 |
+
mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
|
124 |
+
other=1.0,
|
125 |
+
).to(tl.float32)
|
126 |
+
sin = tl.load(
|
127 |
+
SIN,
|
128 |
+
mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
|
129 |
+
other=0.0,
|
130 |
+
).to(tl.float32)
|
131 |
+
x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(
|
132 |
+
tl.float32
|
133 |
+
)
|
134 |
+
x1 = tl.load(
|
135 |
+
X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0
|
136 |
+
).to(tl.float32)
|
137 |
+
if CONJUGATE:
|
138 |
+
sin = -sin
|
139 |
+
x0_cos = x0 * cos
|
140 |
+
x1_sin = x1 * sin
|
141 |
+
out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
|
142 |
+
OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
|
143 |
+
tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))
|
144 |
+
|
145 |
+
|
146 |
+
def apply_rotary(
|
147 |
+
x: torch.Tensor,
|
148 |
+
cos: torch.Tensor,
|
149 |
+
sin: torch.Tensor,
|
150 |
+
seqlen_offsets: Union[int, torch.Tensor] = 0,
|
151 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
152 |
+
max_seqlen: Optional[int] = None,
|
153 |
+
interleaved=False,
|
154 |
+
inplace=False,
|
155 |
+
conjugate=False,
|
156 |
+
) -> torch.Tensor:
|
157 |
+
"""
|
158 |
+
Arguments:
|
159 |
+
x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
|
160 |
+
else (total_seqlen, nheads, headdim).
|
161 |
+
cos: (seqlen_ro, rotary_dim / 2)
|
162 |
+
sin: (seqlen_ro, rotary_dim / 2)
|
163 |
+
seqlen_offsets: integer or integer tensor of size (batch,)
|
164 |
+
cu_seqlens: (batch + 1,) or None
|
165 |
+
max_seqlen: int
|
166 |
+
Returns:
|
167 |
+
y: (batch, seqlen, nheads, headdim)
|
168 |
+
"""
|
169 |
+
|
170 |
+
batch, nheads, seqlen, headdim = x.shape
|
171 |
+
|
172 |
+
batch_ro, seqlen_ro, rotary_dim = cos.shape
|
173 |
+
|
174 |
+
assert batch == batch_ro
|
175 |
+
assert sin.shape == cos.shape
|
176 |
+
rotary_dim *= 2
|
177 |
+
assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
|
178 |
+
assert headdim <= 256, "Only support headdim <= 256"
|
179 |
+
|
180 |
+
assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
|
181 |
+
|
182 |
+
assert (
|
183 |
+
cos.dtype == sin.dtype
|
184 |
+
), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
|
185 |
+
assert (
|
186 |
+
x.dtype == cos.dtype
|
187 |
+
), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"
|
188 |
+
|
189 |
+
cos, sin = cos.contiguous(), sin.contiguous()
|
190 |
+
if isinstance(seqlen_offsets, torch.Tensor):
|
191 |
+
assert seqlen_offsets.shape == (batch,)
|
192 |
+
assert seqlen_offsets.dtype in [torch.int32, torch.int64]
|
193 |
+
seqlen_offsets = seqlen_offsets.contiguous()
|
194 |
+
else:
|
195 |
+
assert seqlen_offsets + seqlen <= seqlen_ro
|
196 |
+
|
197 |
+
output = torch.empty_like(x) if not inplace else x
|
198 |
+
if rotary_dim < headdim and not inplace:
|
199 |
+
output[..., rotary_dim:].copy_(x[..., rotary_dim:])
|
200 |
+
|
201 |
+
BLOCK_K = (
|
202 |
+
32
|
203 |
+
if rotary_dim <= 32
|
204 |
+
else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
|
205 |
+
)
|
206 |
+
grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa
|
207 |
+
BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
|
208 |
+
|
209 |
+
# Need this, otherwise Triton tries to launch from cuda:0 and we get
|
210 |
+
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
|
211 |
+
with torch.cuda.device(x.device.index):
|
212 |
+
rotary_kernel[grid](
|
213 |
+
output, # data ptrs
|
214 |
+
x,
|
215 |
+
cos,
|
216 |
+
sin,
|
217 |
+
cu_seqlens,
|
218 |
+
seqlen_offsets,
|
219 |
+
seqlen, # shapes
|
220 |
+
nheads,
|
221 |
+
rotary_dim,
|
222 |
+
seqlen_ro,
|
223 |
+
seqlen // 128, # key for triton cache (limit number of compilations)
|
224 |
+
output.stride(0), # batch_strides
|
225 |
+
output.stride(-3), # nheads_stride
|
226 |
+
output.stride(-2), # seqlen_stride
|
227 |
+
output.stride(-1), # headdim_stride
|
228 |
+
x.stride(0), # batch_strides
|
229 |
+
x.stride(-3), # nheads stride
|
230 |
+
x.stride(-2), # seqlen stride
|
231 |
+
x.stride(-1), # headdim stride
|
232 |
+
BLOCK_K,
|
233 |
+
isinstance(seqlen_offsets, torch.Tensor),
|
234 |
+
False,
|
235 |
+
interleaved,
|
236 |
+
conjugate,
|
237 |
+
BLOCK_M,
|
238 |
+
)
|
239 |
+
return output
|
240 |
+
|
241 |
+
|
242 |
+
class ApplyRotaryEmb(torch.autograd.Function):
|
243 |
+
@staticmethod
|
244 |
+
def forward(
|
245 |
+
ctx,
|
246 |
+
x,
|
247 |
+
cos,
|
248 |
+
sin,
|
249 |
+
interleaved=False,
|
250 |
+
inplace=False,
|
251 |
+
seqlen_offsets: Union[int, torch.Tensor] = 0,
|
252 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
253 |
+
max_seqlen: Optional[int] = None,
|
254 |
+
):
|
255 |
+
out = apply_rotary(
|
256 |
+
x,
|
257 |
+
cos,
|
258 |
+
sin,
|
259 |
+
seqlen_offsets=seqlen_offsets,
|
260 |
+
cu_seqlens=cu_seqlens,
|
261 |
+
max_seqlen=max_seqlen,
|
262 |
+
interleaved=interleaved,
|
263 |
+
inplace=inplace,
|
264 |
+
)
|
265 |
+
if isinstance(seqlen_offsets, int):
|
266 |
+
ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward
|
267 |
+
ctx.seqlen_offsets = seqlen_offsets
|
268 |
+
else:
|
269 |
+
ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
|
270 |
+
ctx.seqlen_offsets = None
|
271 |
+
ctx.interleaved = interleaved
|
272 |
+
ctx.inplace = inplace
|
273 |
+
ctx.max_seqlen = max_seqlen
|
274 |
+
return out if not inplace else x
|
275 |
+
|
276 |
+
@staticmethod
|
277 |
+
def backward(ctx, do):
|
278 |
+
seqlen_offsets = ctx.seqlen_offsets
|
279 |
+
if seqlen_offsets is None:
|
280 |
+
cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
|
281 |
+
else:
|
282 |
+
cos, sin, cu_seqlens = ctx.saved_tensors
|
283 |
+
# TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
|
284 |
+
# "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
|
285 |
+
if not ctx.interleaved and not ctx.inplace:
|
286 |
+
do = do.clone()
|
287 |
+
dx = apply_rotary(
|
288 |
+
do,
|
289 |
+
cos,
|
290 |
+
sin,
|
291 |
+
seqlen_offsets=seqlen_offsets,
|
292 |
+
cu_seqlens=cu_seqlens,
|
293 |
+
max_seqlen=ctx.max_seqlen,
|
294 |
+
interleaved=ctx.interleaved,
|
295 |
+
inplace=ctx.inplace,
|
296 |
+
conjugate=True,
|
297 |
+
)
|
298 |
+
return dx, None, None, None, None, None, None, None
|
299 |
+
|
300 |
+
|
301 |
+
def apply_rotary_emb(
|
302 |
+
x,
|
303 |
+
cos,
|
304 |
+
sin,
|
305 |
+
interleaved=False,
|
306 |
+
inplace=False,
|
307 |
+
seqlen_offsets: Union[int, torch.Tensor] = 0,
|
308 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
309 |
+
max_seqlen: Optional[int] = None,
|
310 |
+
):
|
311 |
+
"""
|
312 |
+
Arguments:
|
313 |
+
x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
|
314 |
+
else (total_seqlen, nheads, headdim)
|
315 |
+
cos, sin: (seqlen_rotary, rotary_dim / 2)
|
316 |
+
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
|
317 |
+
of 1st half and 2nd half (GPT-NeoX style).
|
318 |
+
inplace: if True, apply rotary embedding in-place.
|
319 |
+
seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
|
320 |
+
Most commonly used in inference when we have KV cache.
|
321 |
+
cu_seqlens: (batch + 1,) or None
|
322 |
+
max_seqlen: int
|
323 |
+
Return:
|
324 |
+
out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
|
325 |
+
else (total_seqlen, nheads, headdim)
|
326 |
+
rotary_dim must be <= headdim
|
327 |
+
Apply rotary embedding to the first rotary_dim of x.
|
328 |
+
"""
|
329 |
+
return ApplyRotaryEmb.apply(
|
330 |
+
x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
|
331 |
+
)
|
332 |
+
|
333 |
+
|
334 |
+
# For backward compatibility
|
335 |
+
apply_rotary_emb_func = apply_rotary_emb
|
336 |
+
|
337 |
+
|
338 |
+
class FastRotaryEmbedding(torch.nn.Module):
|
339 |
+
"""
|
340 |
+
The rotary position embeddings from RoFormer_ (Su et. al).
|
341 |
+
A crucial insight from the method is that the query and keys are
|
342 |
+
transformed by rotation matrices which depend on the relative positions.
|
343 |
+
|
344 |
+
Other implementations are available in the Rotary Transformer repo_ and in
|
345 |
+
GPT-NeoX_, GPT-NeoX was an inspiration
|
346 |
+
|
347 |
+
.. _RoFormer: https://arxiv.org/abs/2104.09864
|
348 |
+
.. _repo: https://github.com/ZhuiyiTechnology/roformer
|
349 |
+
.. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
|
350 |
+
|
351 |
+
If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
|
352 |
+
A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
|
353 |
+
Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
|
354 |
+
"""
|
355 |
+
|
356 |
+
def __init__(
|
357 |
+
self,
|
358 |
+
dim: int,
|
359 |
+
base=10000,
|
360 |
+
interleaved=False,
|
361 |
+
scale_base=None,
|
362 |
+
pos_idx_in_fp32=True,
|
363 |
+
device=None,
|
364 |
+
):
|
365 |
+
"""
|
366 |
+
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
|
367 |
+
of 1st half and 2nd half (GPT-NeoX style).
|
368 |
+
pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
|
369 |
+
otherwise they might be in lower precision.
|
370 |
+
This option was added because previously (before 2023-07-02), when we construct
|
371 |
+
the position indices, we use the dtype of self.inv_freq. In most cases this would
|
372 |
+
be fp32, but if the model is trained in pure bf16 (not mixed precision), then
|
373 |
+
self.inv_freq would be bf16, and the position indices are also in bf16.
|
374 |
+
Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
|
375 |
+
embeddings for some positions will coincide.
|
376 |
+
To maintain compatibility with models previously trained in pure bf16,
|
377 |
+
we add this option.
|
378 |
+
"""
|
379 |
+
super().__init__()
|
380 |
+
self.dim = dim
|
381 |
+
self.base = base
|
382 |
+
self.pos_idx_in_fp32 = pos_idx_in_fp32
|
383 |
+
# Generate and save the inverse frequency buffer (non trainable)
|
384 |
+
inv_freq = self._compute_inv_freq(device)
|
385 |
+
self.register_buffer("inv_freq", inv_freq)
|
386 |
+
self.interleaved = interleaved
|
387 |
+
self.scale_base = scale_base
|
388 |
+
scale = (
|
389 |
+
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
|
390 |
+
if scale_base is not None
|
391 |
+
else None
|
392 |
+
)
|
393 |
+
self.register_buffer("scale", scale, persistent=False)
|
394 |
+
|
395 |
+
self._seq_len_cached = 0
|
396 |
+
self._cos_cached = None
|
397 |
+
self._sin_cached = None
|
398 |
+
self._cos_k_cached = None
|
399 |
+
self._sin_k_cached = None
|
400 |
+
self.cos = None
|
401 |
+
self.sin = None
|
402 |
+
|
403 |
+
def _compute_inv_freq(self, device=None):
|
404 |
+
return 1.0 / (
|
405 |
+
self.base
|
406 |
+
** (torch.arange(0, self.dim, 2, device=device) / self.dim)
|
407 |
+
# ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim)
|
408 |
+
)
|
409 |
+
|
410 |
+
def _update_cos_sin_cache(self, seqlen, position_id, device=None, dtype=None):
|
411 |
+
|
412 |
+
if (
|
413 |
+
seqlen > self._seq_len_cached
|
414 |
+
):
|
415 |
+
self._seq_len_cached = seqlen
|
416 |
+
# We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
|
417 |
+
# And the output of arange can be quite large, so bf16 would lose a lot of precision.
|
418 |
+
# However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
|
419 |
+
if self.pos_idx_in_fp32:
|
420 |
+
t = torch.arange(seqlen, device=device, dtype=torch.float32)
|
421 |
+
# We want fp32 here as well since inv_freq will be multiplied with t, and the output
|
422 |
+
# will be large. Having it in bf16 will lose a lot of precision and cause the
|
423 |
+
# cos & sin output to change significantly.
|
424 |
+
# We want to recompute self.inv_freq if it was not loaded in fp32
|
425 |
+
if self.inv_freq.dtype != torch.float32:
|
426 |
+
inv_freq = self._compute_inv_freq(device=device)
|
427 |
+
else:
|
428 |
+
inv_freq = self.inv_freq
|
429 |
+
else:
|
430 |
+
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
431 |
+
inv_freq = self.inv_freq
|
432 |
+
freqs = torch.einsum("i,j->ij", t, inv_freq)
|
433 |
+
if self.scale is None:
|
434 |
+
self._cos_cached = torch.cos(freqs).to(dtype)
|
435 |
+
self._sin_cached = torch.sin(freqs).to(dtype)
|
436 |
+
|
437 |
+
else:
|
438 |
+
power = (
|
439 |
+
torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
|
440 |
+
- seqlen // 2
|
441 |
+
) / self.scale_base
|
442 |
+
scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
|
443 |
+
# We want the multiplication by scale to happen in fp32
|
444 |
+
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
|
445 |
+
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
|
446 |
+
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
|
447 |
+
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
|
448 |
+
|
449 |
+
def forward(
|
450 |
+
self,
|
451 |
+
q: torch.Tensor,
|
452 |
+
k: torch.Tensor,
|
453 |
+
position_ids: torch.Tensor,
|
454 |
+
max_seqlen,
|
455 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
456 |
+
"""
|
457 |
+
q: (batch, nheads, seqlen, headdim)
|
458 |
+
k: (batch, nheads, seqlen, headdim)
|
459 |
+
position_id: (batch, seqlen)
|
460 |
+
max_seqlen: int
|
461 |
+
layer_id: int
|
462 |
+
only if layer_id == 0, then update cons and sin
|
463 |
+
Apply rotary embedding *inplace* to q k.
|
464 |
+
"""
|
465 |
+
|
466 |
+
self._update_cos_sin_cache(max_seqlen, position_ids, device=q.device, dtype=q.dtype)
|
467 |
+
cos, sin = F.embedding(position_ids, self._cos_cached), F.embedding(position_ids, self._sin_cached)
|
468 |
+
|
469 |
+
q = apply_rotary_emb_func(
|
470 |
+
q,
|
471 |
+
cos,
|
472 |
+
sin,
|
473 |
+
interleaved=self.interleaved,
|
474 |
+
inplace=True
|
475 |
+
)
|
476 |
+
k = apply_rotary_emb_func(
|
477 |
+
k,
|
478 |
+
cos,
|
479 |
+
sin,
|
480 |
+
interleaved=self.interleaved,
|
481 |
+
inplace=True
|
482 |
+
)
|
483 |
+
return q, k
|
visual.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from argparse import Namespace
|
4 |
+
import xformers.ops as xops
|
5 |
+
from transformers.activations import ACT2FN
|
6 |
+
|
7 |
+
|
8 |
+
class PatchEmbedding(nn.Module):
|
9 |
+
def __init__(self, config):
|
10 |
+
super().__init__()
|
11 |
+
self.proj = nn.Conv2d(config.in_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size)
|
12 |
+
self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size))
|
13 |
+
self.position_embedding = nn.Embedding(config.num_positions, config.hidden_size)
|
14 |
+
|
15 |
+
def forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L, D)":
|
16 |
+
x = self.proj(images)
|
17 |
+
x = x.flatten(2).transpose(1, 2)
|
18 |
+
cls_token = self.cls_embedding.expand(x.shape[0], -1, -1)
|
19 |
+
x = torch.cat((cls_token, x), dim=1)
|
20 |
+
x += self.position_embedding.weight.unsqueeze(0)
|
21 |
+
return x
|
22 |
+
|
23 |
+
|
24 |
+
class Attention(nn.Module):
|
25 |
+
def __init__(self, config):
|
26 |
+
super().__init__()
|
27 |
+
self.num_heads = config.num_heads
|
28 |
+
head_dim = config.hidden_size // config.num_heads
|
29 |
+
self.scale = head_dim ** -0.5
|
30 |
+
self.query_key_value = nn.Linear(config.hidden_size, config.hidden_size * 3)
|
31 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
32 |
+
self.output_dropout = torch.nn.Dropout(config.dropout_prob)
|
33 |
+
|
34 |
+
def forward(self, x: "tensor(B, L, D)") -> "tensor(B, L, D)":
|
35 |
+
B, L, _ = x.shape
|
36 |
+
qkv = self.query_key_value(x)
|
37 |
+
qkv = qkv.reshape(B, L, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4) # 3, B, L, H, D
|
38 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
39 |
+
|
40 |
+
out = xops.memory_efficient_attention(
|
41 |
+
q, k, v, scale=self.scale,
|
42 |
+
)
|
43 |
+
output = self.dense(out.view(B, L, -1))
|
44 |
+
output = self.output_dropout(output)
|
45 |
+
return output
|
46 |
+
|
47 |
+
def attention(self, q, k, v):
|
48 |
+
attn_weights = torch.matmul(q * self.scale, k.transpose(-2, -1))
|
49 |
+
attn_weights = attn_weights.softmax(dim=-1)
|
50 |
+
output = torch.matmul(attn_weights, v)
|
51 |
+
return output
|
52 |
+
|
53 |
+
|
54 |
+
class MLP(nn.Module):
|
55 |
+
def __init__(self, config):
|
56 |
+
super().__init__()
|
57 |
+
self.config = config
|
58 |
+
self.activation_fn = ACT2FN[config.hidden_act]
|
59 |
+
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
60 |
+
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
61 |
+
|
62 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
63 |
+
x = self.fc1(x)
|
64 |
+
x = self.activation_fn(x)
|
65 |
+
x = self.fc2(x)
|
66 |
+
return x
|
67 |
+
|
68 |
+
|
69 |
+
class TransformerLayer(nn.Module):
|
70 |
+
def __init__(self, config):
|
71 |
+
super().__init__()
|
72 |
+
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
73 |
+
self.attention = Attention(config)
|
74 |
+
self.mlp = MLP(config)
|
75 |
+
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
76 |
+
|
77 |
+
def forward(self, hidden_states):
|
78 |
+
attention_input = hidden_states
|
79 |
+
attention_output = self.input_layernorm(self.attention(attention_input))
|
80 |
+
hidden_states = attention_input + attention_output
|
81 |
+
mlp_input = hidden_states
|
82 |
+
mlp_output = self.post_attention_layernorm(self.mlp(mlp_input))
|
83 |
+
output = mlp_input + mlp_output
|
84 |
+
return output
|
85 |
+
|
86 |
+
|
87 |
+
class Transformer(nn.Module):
|
88 |
+
def __init__(self, config):
|
89 |
+
super().__init__()
|
90 |
+
self.layers = nn.ModuleList([TransformerLayer(config) for _ in range(config.num_hidden_layers)])
|
91 |
+
|
92 |
+
def forward(self, hidden_states):
|
93 |
+
for layer_module in self.layers:
|
94 |
+
hidden_states = layer_module(hidden_states)
|
95 |
+
return hidden_states
|
96 |
+
|
97 |
+
|
98 |
+
class GLU(nn.Module):
|
99 |
+
def __init__(self, config, in_features):
|
100 |
+
super().__init__()
|
101 |
+
self.linear_proj = nn.Linear(in_features, config.hidden_size, bias=False)
|
102 |
+
self.norm1 = nn.LayerNorm(config.hidden_size)
|
103 |
+
self.act1 = nn.GELU()
|
104 |
+
self.act2 = nn.functional.silu
|
105 |
+
self.dense_h_to_4h = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
|
106 |
+
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
|
107 |
+
self.dense_4h_to_h = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
|
108 |
+
|
109 |
+
def forward(self, x):
|
110 |
+
x = self.linear_proj(x)
|
111 |
+
x = self.act1(self.norm1(x))
|
112 |
+
x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x)
|
113 |
+
x = self.dense_4h_to_h(x)
|
114 |
+
return x
|
115 |
+
|
116 |
+
|
117 |
+
class EVA2CLIPModel(nn.Module):
|
118 |
+
def __init__(self, config):
|
119 |
+
super().__init__()
|
120 |
+
vision_config = Namespace(**config.vision_config)
|
121 |
+
self.patch_embedding = PatchEmbedding(vision_config)
|
122 |
+
self.transformer = Transformer(vision_config)
|
123 |
+
self.linear_proj = GLU(config, in_features=vision_config.hidden_size)
|
124 |
+
self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
125 |
+
self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
126 |
+
|
127 |
+
def forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L, D)":
|
128 |
+
x = self.patch_embedding(images)
|
129 |
+
x = self.transformer(x)
|
130 |
+
x = x[:, 1:]
|
131 |
+
x = self.linear_proj(x)
|
132 |
+
boi = self.boi.expand(x.shape[0], -1, -1)
|
133 |
+
eoi = self.eoi.expand(x.shape[0], -1, -1)
|
134 |
+
x = torch.cat((boi, x, eoi), dim=1)
|
135 |
+
return x
|