Spaces: csuhan/OneLLM (status: Runtime error)

Commit: Upload folder using huggingface_hub

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- .gitignore +10 -0
- README.md +90 -5
- app.py +263 -268
- config/llama2/7B.json +1 -0
- config/llama2/tokenizer.model +3 -0
- data/__pycache__/conversation_lib.cpython-310.pyc +0 -0
- data/__pycache__/conversation_lib.cpython-39.pyc +0 -0
- data/__pycache__/fintune_dataset.cpython-310.pyc +0 -0
- data/__pycache__/fintune_dataset.cpython-39.pyc +0 -0
- data/__pycache__/imu_utils.cpython-310.pyc +0 -0
- data/__pycache__/imu_utils.cpython-39.pyc +0 -0
- data/__pycache__/video_utils.cpython-310.pyc +0 -0
- data/__pycache__/video_utils.cpython-39.pyc +0 -0
- data/conversation_lib.py +369 -0
- data/fintune_dataset.py +449 -0
- data/imu_utils.py +257 -0
- data/video_utils.py +204 -0
- demos/multi_turn_mm.py +300 -0
- lib/__pycache__/point_utils.cpython-310.pyc +0 -0
- lib/point_utils.py +191 -0
- lib/pointnet2/pointnet2_modules.py +160 -0
- lib/pointnet2/pointnet2_utils.py +290 -0
- lib/pointnet2/pytorch_utils.py +236 -0
- lib/pointnet2/setup.py +23 -0
- lib/pointnet2/src/ball_query.cpp +24 -0
- lib/pointnet2/src/ball_query_gpu.cu +67 -0
- lib/pointnet2/src/ball_query_gpu.h +15 -0
- lib/pointnet2/src/cuda_utils.h +15 -0
- lib/pointnet2/src/group_points.cpp +34 -0
- lib/pointnet2/src/group_points_gpu.cu +86 -0
- lib/pointnet2/src/group_points_gpu.h +22 -0
- lib/pointnet2/src/interpolate.cpp +53 -0
- lib/pointnet2/src/interpolate_gpu.cu +161 -0
- lib/pointnet2/src/interpolate_gpu.h +30 -0
- lib/pointnet2/src/pointnet2_api.cpp +24 -0
- lib/pointnet2/src/sampling.cpp +45 -0
- lib/pointnet2/src/sampling_gpu.cu +253 -0
- lib/pointnet2/src/sampling_gpu.h +29 -0
- model/LLM/__init__.py +1 -0
- model/LLM/__pycache__/__init__.cpython-310.pyc +0 -0
- model/LLM/__pycache__/__init__.cpython-39.pyc +0 -0
- model/LLM/__pycache__/onellm.cpython-310.pyc +0 -0
- model/LLM/__pycache__/onellm.cpython-39.pyc +0 -0
- model/LLM/onellm.py +495 -0
- model/__init__.py +0 -0
- model/__pycache__/__init__.cpython-310.pyc +0 -0
- model/__pycache__/__init__.cpython-39.pyc +0 -0
- model/__pycache__/components.cpython-39.pyc +0 -0
- model/__pycache__/meta.cpython-310.pyc +0 -0
- model/__pycache__/meta.cpython-39.pyc +0 -0
.gitignore
ADDED
@@ -0,0 +1,10 @@
+__pycache__
+*.pyc
+*.egg-info
+dist
+
+output
+output_dir
+*.pth
+*.log
+weights
README.md
CHANGED
@@ -1,15 +1,100 @@
 ---
-title:
 emoji: 🚀
 colorFrom: red
 colorTo: indigo
 sdk: gradio
-sdk_version:
 app_file: app.py
 pinned: false
 ---
 
-
-The official demo for LLaMA-Adapter V2.
-Please refer to our [arXiv paper](https://arxiv.org/abs/2303.16199) and [github](https://github.com/ZrrSkywalker/LLaMA-Adapter) for more details.
 
 ---
+title: OneLLM
 emoji: 🚀
 colorFrom: red
 colorTo: indigo
 sdk: gradio
+sdk_version: 4.7.1
 app_file: app.py
 pinned: false
 ---
 
+# OneLLM: One Framework to Align All Modalities with Language
 
+[[Project Page](https://onellm.csuhan.com)] [[Paper](#)] [[Web Demo](https://huggingface.co/spaces/csuhan/OneLLM)]
+
+Authors: [Jiaming Han](), [Kaixiong Gong](), [Yiyuan Zhang](), [Jiaqi Wang](), [Kaipeng Zhang](), [Dahua Lin](), [Yu Qiao](), [Peng Gao](), [Xiangyu Yue]().
+
+## News
+
+- **2023.12.01** Release model weights and inference code.
+
+## Contents
+
+- [Install](#install)
+- [Models](#models)
+- [Demo](#demo)
+
+<!-- - [Evaluation](#evaluation) -->
+
+<!-- - [Training](#training) -->
+
+### TODO
+
+- [ ] Data
+- [ ] Evaluation
+- [ ] Training
+
+### Install
+
+1. Clone the repo into a local folder.
+
+```bash
+git clone https://github.com/csuhan/OneLLM
+
+cd OneLLM
+```
+
+2. Install packages.
+
+```bash
+conda create -n onellm python=3.9 -y
+conda activate onellm
+
+pip install -r requirements.txt
+
+# install pointnet
+cd lib/pointnet2
+python setup.py install
+```
+
+3. Install Apex. (Optional)
+
+```bash
+git clone https://github.com/NVIDIA/apex
+cd apex
+pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
+```
+
+### Models
+
+We provide a preview model at: [csuhan/OneLLM-7B](https://huggingface.co/csuhan/OneLLM-7B).
+
+### Demo
+
+**Huggingface Demo:** [csuhan/OneLLM](https://huggingface.co/spaces/csuhan/OneLLM).
+
+**Local Demo:** Assume you have downloaded the weights to ${WEIGHTS_DIR}. Then run the following command to start a gradio demo locally.
+
+```bash
+python demos/multi_turn_mm.py --gpu_ids 0 --tokenizer_path config/llama2/tokenizer.model --llama_config config/llama2/7B.json --pretrained_path ${WEIGHTS_DIR}/consolidated.00-of-01.pth
+```
+
+<!-- ### Evaluation -->
+
+<!-- ### Training -->
+
+## Citation
+
+```
+@article{han2023onellm,
+  title={OneLLM: One Framework to Align All Modalities with Language},
+  author={Han, Jiaming and Gong, Kaixiong and Zhang, Yiyuan and Wang, Jiaqi and Zhang, Kaipeng and Lin, Dahua and Qiao, Yu and Gao, Peng and Yue, Xiangyu},
+  journal={arXiv preprint arXiv:xxxx},
+  year={2023}
+}
+```
+
+## Acknowledgement
+
+[LLaMA](https://github.com/facebookresearch/llama), [LLaMA-Adapter](https://github.com/OpenGVLab/LLaMA-Adapter), [LLaMA2-Accessory](https://github.com/Alpha-VLLM/LLaMA2-Accessory), [Meta-Transformer](https://github.com/invictus717/MetaTransformer), [ChatBridge](https://github.com/joez17/ChatBridge)
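
As a side note on the local-demo instructions above: the Space's own app.py (the next file in this diff) does not assume a pre-downloaded ${WEIGHTS_DIR} but resolves the checkpoint with `hf_hub_download`. A minimal sketch of that approach, assuming only the `csuhan/OneLLM-7B` repo layout referenced in this commit:

```python
# Sketch: fetch the OneLLM-7B checkpoint the same way app.py's DemoConfig does,
# then point the local demo at the cached file. Assumes huggingface_hub is installed.
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="csuhan/OneLLM-7B",            # preview model referenced in the README
    filename="consolidated.00-of-01.pth",  # same filename as in the local-demo command
)
print(ckpt_path)  # pass this as --pretrained_path to demos/multi_turn_mm.py
```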
app.py
CHANGED
@@ -1,277 +1,272 @@
-import json
-import os
-import glob
 import sys
-import
-
-
 
-from huggingface_hub import hf_hub_download
-from PIL import Image
-import gradio as gr
 import torch
-
-
-from llama import LLaMA, ModelArgs, Tokenizer, Transformer, VisionModel
-
-os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
-
-PROMPT_DICT = {
-    "prompt_input": (
-        "Below is an instruction that describes a task, paired with an input that provides further context. "
-        "Write a response that appropriately completes the request.\n\n"
-        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
-    ),
-    "prompt_no_input": (
-        "Below is an instruction that describes a task. "
-        "Write a response that appropriately completes the request.\n\n"
-        "### Instruction:\n{instruction}\n\n### Response:"
-    ),
-}
-
-
-def setup_model_parallel() -> Tuple[int, int]:
-    os.environ['RANK'] = '0'
-    os.environ['WORLD_SIZE'] = '1'
-    os.environ['MP'] = '1'
-    os.environ['MASTER_ADDR'] = '127.0.0.1'
-    os.environ['MASTER_PORT'] = '2223'
-    local_rank = int(os.environ.get("LOCAL_RANK", -1))
-    world_size = int(os.environ.get("WORLD_SIZE", -1))
-
-    torch.distributed.init_process_group("nccl")
-    initialize_model_parallel(world_size)
-    torch.cuda.set_device(local_rank)
-
-    # seed must be the same in all processes
-    torch.manual_seed(1)
-    return local_rank, world_size
-
-
-def load(
-    ckpt0_path: str,
-    ckpt1_path: str,
-    param_path: str,
-    tokenizer_path: str,
-    instruct_adapter_path: str,
-    caption_adapter_path: str,
-    local_rank: int,
-    world_size: int,
-    max_seq_len: int,
-    max_batch_size: int,
-) -> LLaMA:
-    start_time = time.time()
-    print("Loading")
-    instruct_adapter_checkpoint = torch.load(
-        instruct_adapter_path, map_location="cpu")
-    caption_adapter_checkpoint = torch.load(
-        caption_adapter_path, map_location="cpu")
-    with open(param_path, "r") as f:
-        params = json.loads(f.read())
-
-    model_args: ModelArgs = ModelArgs(
-        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
-    )
-    model_args.adapter_layer = int(
-        instruct_adapter_checkpoint['adapter_query.weight'].shape[0] / model_args.adapter_len)
-    model_args.cap_adapter_layer = int(
-        caption_adapter_checkpoint['cap_adapter_query.weight'].shape[0] / model_args.cap_adapter_len)
-
-    tokenizer = Tokenizer(model_path=tokenizer_path)
-    model_args.vocab_size = tokenizer.n_words
-    torch.set_default_tensor_type(torch.cuda.HalfTensor)
-    model = Transformer(model_args)
-
-    # To reduce memory usuage
-    ckpt0 = torch.load(ckpt0_path, map_location='cuda')
-    model.load_state_dict(ckpt0, strict=False)
-    del ckpt0
-    torch.cuda.empty_cache()
-
-    ckpt1 = torch.load(ckpt1_path, map_location='cuda')
-    model.load_state_dict(ckpt1, strict=False)
-    del ckpt1
-    torch.cuda.empty_cache()
-
-    vision_model = VisionModel(model_args)
-
-    torch.set_default_tensor_type(torch.FloatTensor)
-    model.load_state_dict(instruct_adapter_checkpoint, strict=False)
-    model.load_state_dict(caption_adapter_checkpoint, strict=False)
-    vision_model.load_state_dict(caption_adapter_checkpoint, strict=False)
-
-    generator = LLaMA(model, tokenizer, vision_model)
-    print(f"Loaded in {time.time() - start_time:.2f} seconds")
-    return generator
-
-
-def instruct_generate(
-    instruct: str,
-    input: str = 'none',
-    max_gen_len=512,
-    temperature: float = 0.1,
-    top_p: float = 0.75,
-):
-    if input == 'none':
-        prompt = PROMPT_DICT['prompt_no_input'].format_map(
-            {'instruction': instruct, 'input': ''})
-    else:
-        prompt = PROMPT_DICT['prompt_input'].format_map(
-            {'instruction': instruct, 'input': input})
-
-    results = generator.generate(
-        [prompt], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p
-    )
-    result = results[0].strip()
-    print(result)
-    return result
-
-
-def caption_generate(
-    img: str,
-    max_gen_len=512,
-    temperature: float = 0.1,
-    top_p: float = 0.75,
-):
-    imgs = [Image.open(img).convert('RGB')]
-    prompts = ["Generate caption of this image :",] * len(imgs)
-
-    results = generator.generate(
-        prompts, imgs=imgs, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p
-    )
-    result = results[0].strip()
-    print(result)
-    return result
-
-
-def download_llama_adapter(instruct_adapter_path, caption_adapter_path):
-    if not os.path.exists(instruct_adapter_path):
-        os.system(
-            f"wget -q -O {instruct_adapter_path} https://github.com/ZrrSkywalker/LLaMA-Adapter/releases/download/v.1.0.0/llama_adapter_len10_layer30_release.pth")
-
-    if not os.path.exists(caption_adapter_path):
-        os.system(
-            f"wget -q -O {caption_adapter_path} https://github.com/ZrrSkywalker/LLaMA-Adapter/releases/download/v.1.0.0/llama_adapter_len10_layer30_caption_vit_l.pth")
-
-
-# ckpt_path = "/data1/llma/7B/consolidated.00.pth"
-# param_path = "/data1/llma/7B/params.json"
-# tokenizer_path = "/data1/llma/tokenizer.model"
-ckpt0_path = hf_hub_download(
-    repo_id="csuhan/llama_storage", filename="consolidated.00_part0.pth")
-ckpt1_path = hf_hub_download(
-    repo_id="csuhan/llama_storage", filename="consolidated.00_part1.pth")
-param_path = hf_hub_download(
-    repo_id="nyanko7/LLaMA-7B", filename="params.json")
-tokenizer_path = hf_hub_download(
-    repo_id="nyanko7/LLaMA-7B", filename="tokenizer.model")
-instruct_adapter_path = "llama_adapter_len10_layer30_release.pth"
-caption_adapter_path = "llama_adapter_len10_layer30_caption_vit_l.pth"
-max_seq_len = 512
-max_batch_size = 1
-
-# download models
-# download_llama_adapter(instruct_adapter_path, caption_adapter_path)
-
-local_rank, world_size = setup_model_parallel()
-if local_rank > 0:
-    sys.stdout = open(os.devnull, "w")
-
-generator = load(
-    ckpt0_path, ckpt1_path, param_path, tokenizer_path, instruct_adapter_path, caption_adapter_path, local_rank, world_size, max_seq_len, max_batch_size
-)
-
-
-def create_instruct_demo():
-    with gr.Blocks() as instruct_demo:
-        with gr.Row():
-            with gr.Column():
-                instruction = gr.Textbox(lines=2, label="Instruction")
-                input = gr.Textbox(
-                    lines=2, label="Context input", placeholder='none')
-                max_len = gr.Slider(minimum=1, maximum=512,
-                                    value=128, label="Max length")
-                with gr.Accordion(label='Advanced options', open=False):
-                    temp = gr.Slider(minimum=0, maximum=1,
-                                     value=0.1, label="Temperature")
-                    top_p = gr.Slider(minimum=0, maximum=1,
-                                      value=0.75, label="Top p")
-
-                run_botton = gr.Button("Run")
-
-            with gr.Column():
-                outputs = gr.Textbox(lines=10, label="Output")
-
-            inputs = [instruction, input, max_len, temp, top_p]
-
-            examples = [
-                "Tell me about alpacas.",
-                "Write a Python program that prints the first 10 Fibonacci numbers.",
-                "Write a conversation between the sun and pluto.",
-                "Write a theory to explain why cat never existed",
-            ]
-            examples = [
-                [x, "none", 128, 0.1, 0.75]
-                for x in examples]
-
-            gr.Examples(
-                examples=examples,
-                inputs=inputs,
-                outputs=outputs,
-                fn=instruct_generate,
-                cache_examples=os.getenv('SYSTEM') == 'spaces'
-            )
-        run_botton.click(fn=instruct_generate, inputs=inputs, outputs=outputs)
-    return instruct_demo
 
 
-                                      value=0.75, label="Top p")
-
-                run_botton = gr.Button("Run")
-
-            with gr.Column():
-                outputs = gr.Textbox(lines=10, label="Output")
-
-            inputs = [img, max_len, temp, top_p]
-
-            examples = glob.glob("caption_demo/*.jpg")
-            examples = [
-                [x, 64, 0.1, 0.75]
-                for x in examples]
-
-            gr.Examples(
-                examples=examples,
-                inputs=inputs,
-                outputs=outputs,
-                fn=caption_generate,
-                cache_examples=os.getenv('SYSTEM') == 'spaces'
-            )
-        run_botton.click(fn=caption_generate, inputs=inputs, outputs=outputs)
-    return instruct_demo
 
 
-with gr.
 import sys
+import os
+
+import argparse
+import multiprocessing as mp
+import numpy as np
+from typing import List, Optional
 
 import torch
+import torch.distributed as dist
 
+from fairscale.nn.model_parallel import initialize as fs_init
 
+import gradio as gr
+from util.misc import setup_for_distributed
+from util.misc import default_tensor_type
+from model.meta import MetaModel
+from data.conversation_lib import conv_templates, SeparatorStyle
+from PIL import Image
+import torchvision.transforms as transforms
+from data.fintune_dataset import make_audio_features
+from data import video_utils
+from dataclasses import dataclass
+from huggingface_hub import hf_hub_download
 
+T_random_resized_crop = transforms.Compose([
+    transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333), interpolation=3,
+                                 antialias=None),  # 3 is bicubic
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
+
+
+def load_audio(audio_path):
+    fbank = make_audio_features(audio_path, mel_bins=128)
+    fbank = fbank.transpose(0, 1)[None]  # [1, 128, 1024]
+    return fbank
+
+def load_video(video_path):
+    video_feats = video_utils.load_and_transform_video_data(video_path, video_path, clip_duration=1, clips_per_video=5)
+    return video_feats[:, :, 0]
+
+
+def model_worker(
+    rank: int, args: argparse.Namespace, barrier: mp.Barrier,
+    request_queue: mp.Queue, response_queue: Optional[mp.Queue] = None,
+) -> None:
+    """
+    The worker function that manipulates the GPU to run the inference.
+    Exact n_gpu workers are started, with each one operating on a separate GPU.
+
+    Args:
+        rank (int): Distributed rank of the worker.
+        args (argparse.Namespace): All command line arguments.
+        barrier (multiprocessing.Barrier): A barrier used to delay the start
+            of Web UI to be after the start of the model.
+    """
+
+    world_size = len(args.gpu_ids)
+    gpu_id = args.gpu_ids[rank]
+    dist.init_process_group(
+        backend="nccl", rank=rank, world_size=world_size,
+        init_method=f"tcp://{args.master_addr}:{args.master_port}",
+    )
+    print(f"| distributed init on worker {rank}/{world_size}. "
+          f"using gpu: {gpu_id}")
+    fs_init.initialize_model_parallel(world_size)
+    torch.cuda.set_device(gpu_id)
 
+    torch.manual_seed(1)
+    np.random.seed(1)
+
+    # set the print behavior.
+    setup_for_distributed(rank == 0)
+
+    target_dtype = {
+        "bf16": torch.bfloat16,
+        "fp16": torch.float16
+    }[args.dtype]
+    with default_tensor_type(dtype=target_dtype, device="cuda"):
+        model = MetaModel(args.llama_type, args.llama_config, tokenizer_path=args.tokenizer_path)
+    print("Loading pretrained weights ...")
+    checkpoint = torch.load(args.pretrained_path, map_location='cpu')
+    msg = model.load_state_dict(checkpoint, strict=False)
+    print("load result:\n", msg)
+    model.cuda()
+    model.eval()
+    print(f"Model = {str(model)}")
+
+    barrier.wait()
+
+    while True:
+        img_path, audio_path, video_path, chatbot, max_gen_len, temperature, top_p, modality = request_queue.get()
+        if 'image' in modality and img_path is not None:
+            image = Image.open(img_path).convert('RGB')
+            inputs = T_random_resized_crop(image)
+        elif 'video' in modality and video_path is not None:
+            inputs = load_video(video_path)
+        elif 'audio' in modality and audio_path is not None:
+            inputs = load_audio(audio_path)
+        else:
+            inputs = None
+
+        if inputs is not None:
+            inputs = inputs[None].cuda().to(target_dtype)
+
+        conv = conv_templates["v1"].copy()
+        for user, bot in chatbot:
+            conv.append_message(conv.roles[0], user)
+            conv.append_message(conv.roles[1], bot)
+
+        with torch.cuda.amp.autocast(dtype=target_dtype):
+            print(conv.get_prompt())
+            for stream_response in model.stream_generate(
+                conv.get_prompt(), inputs,
+                max_gen_len=max_gen_len, temperature=temperature, top_p=top_p,
+                modal = modality
+            ):
+                conv_sep = (
+                    conv.sep
+                    if conv.sep_style == SeparatorStyle.SINGLE
+                    else conv.sep2
+                )
+                end_pos = stream_response["text"].find(conv_sep)
+                if end_pos != -1:
+                    stream_response["text"] = (
+                        stream_response['text'][:end_pos].rstrip() + "\n"
+                    )
+                    stream_response["end_of_content"] = True
+
+                # keep a few characters if not end_of_content to avoid sending
+                # part of conv_sep before all of it is generated.
+                if not stream_response["end_of_content"]:
+                    if len(stream_response["text"]) < len(conv_sep):
+                        continue
+                    stream_response["text"] = (
+                        stream_response["text"][:-len(conv_sep)]
+                    )
+
+                if response_queue is not None:
+                    response_queue.put(stream_response)
+
+                if stream_response["end_of_content"]:
+                    break
+
+
+def gradio_worker(
+    request_queues: List[mp.Queue], response_queue: mp.Queue,
+    args: argparse.Namespace, barrier: mp.Barrier,
+) -> None:
+    """
+    The gradio worker is responsible for displaying the WebUI and relay the
+    requests to model workers. It should be launched only once.
+
+    Args:
+        request_queues (List[mp.Queue]): A list of request queues (one for
+            each model worker).
+        args (argparse.Namespace): All command line arguments.
+        barrier (multiprocessing.Barrier): A barrier used to delay the start
+            of Web UI to be after the start of the model.
+    """
+
+    def show_user_input(msg, chatbot):
+        return "", chatbot + [[msg, None]]
+
+    def stream_model_output(img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality):
+        for queue in request_queues:
+            queue.put((img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality))
+        while True:
+            content_piece = response_queue.get()
+            chatbot[-1][1] = content_piece["text"]
+            yield chatbot
+            if content_piece["end_of_content"]:
+                break
+
+    def undo(chatbot):
+        if len(chatbot) > 0:
+            chatbot = chatbot[:-1]
+        return chatbot
+
+    def clear():
+        chatbot = []
+        msg = ""
+        return chatbot, msg
+
+    CSS ="""
+    .contain { display: flex; flex-direction: column; }
+    #component-0 { height: 100%; }
+    #chatbot { flex-grow: 1; overflow: auto;}
+    """
+    with gr.Blocks(css=CSS) as demo:
+        gr.Markdown("## OneLLM: One Framework to Align All Modalities with Language")
+        with gr.Row(equal_height=True):
+            with gr.Column(scale=1):
+                img_path = gr.Image(label='Image Input', type='filepath')
+                video_path = gr.Video(label='Video Input')
+                audio_path = gr.Audio(label='Audio Input', type='filepath', sources=['upload'])
+                modality = gr.Radio(choices=['image', 'audio', 'video'], value='image', interactive=True, label='Input Modalities')
+
+            with gr.Column(scale=2):
+                chatbot = gr.Chatbot(elem_id="chatbot")
+                msg = gr.Textbox()
 
+        with gr.Row():
+            submit_button = gr.Button("Submit", variant="primary")
+            undo_button = gr.Button("Undo")
+            clear_button = gr.ClearButton([chatbot, msg, img_path, audio_path, video_path, modality])
+        with gr.Row():
+            max_gen_len = gr.Slider(
+                minimum=1, maximum=args.model_max_seq_len // 2,
+                value=args.model_max_seq_len // 2, interactive=True,
+                label="Single-turn max response length",
+            )
+            gen_t = gr.Slider(
+                minimum=0, maximum=1, value=0.1, interactive=True,
+                label="Temperature",
+            )
+            top_p = gr.Slider(
+                minimum=0, maximum=1, value=0.75, interactive=True,
+                label="Top-p",
+            )
+        msg.submit(
+            show_user_input, [msg, chatbot], [msg, chatbot],
+        ).then(
+            stream_model_output, [img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot,
+        )
+        submit_button.click(
+            show_user_input, [msg, chatbot], [msg, chatbot],
+        ).then(
+            stream_model_output, [img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot,
+        )
+        undo_button.click(undo, chatbot, chatbot)
+        # img_path.change(clear, [], [chatbot, msg])
+    barrier.wait()
+    demo.queue(api_open=True).launch(share=True, max_threads=1)
+
+
+@dataclass
+class DemoConfig:
+    gpu_ids = [0]
+    tokenizer_path = "config/llama2/tokenizer.model"
+    llama_type = "onellm"
+    llama_config = "config/llama2/7B.json"
+    model_max_seq_len = 2048
+    # pretrained_path = "weights/7B_2048/consolidated.00-of-01.pth"
+    pretrained_path = hf_hub_download(repo_id="csuhan/OneLLM-7B", filename="consolidated.00-of-01.pth")
+    master_port = 23861
+    master_addr = "127.0.0.1"
+    dtype = "fp16"
+
+if __name__ == "__main__":
+    args = DemoConfig()
+    # using the default "fork" method messes up some imported libs (e.g.,
+    # pandas)
+    mp.set_start_method("spawn")
+
+    # setup the queues and start the model workers
+    request_queues = []
+    response_queue = mp.Queue()
+    worker_processes = []
+    barrier = mp.Barrier(len(args.gpu_ids) + 1)
+    for rank, gpu_id in enumerate(args.gpu_ids):
+        request_queue = mp.Queue()
+        rank_response_queue = response_queue if rank == 0 else None
+        process = mp.Process(
+            target=model_worker,
+            args=(rank, args, barrier, request_queue, rank_response_queue),
+        )
+        process.start()
+        worker_processes.append(process)
+        request_queues.append(request_queue)
 
+    gradio_worker(request_queues, response_queue, args, barrier)
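
The new app.py splits the demo into one Gradio front-end process and one model worker per GPU, connected by multiprocessing queues: the front-end puts a request tuple on every request queue, and the rank-0 worker streams back dicts with "text" and "end_of_content" keys until the turn is finished. A minimal, model-free sketch of that hand-off pattern follows; the echo worker is hypothetical and only stands in for `model_worker` to show the protocol.

```python
# Sketch of the request/response queue protocol used by app.py, with a dummy
# worker standing in for model_worker. Nothing here touches a GPU or a model.
import multiprocessing as mp


def echo_worker(request_queue: mp.Queue, response_queue: mp.Queue) -> None:
    # Hypothetical stand-in for model_worker: stream the prompt back word by word.
    while True:
        prompt = request_queue.get()
        words = prompt.split()
        for i in range(1, len(words) + 1):
            response_queue.put({
                "text": " ".join(words[:i]),
                "end_of_content": i == len(words),
            })


if __name__ == "__main__":
    mp.set_start_method("spawn")  # same choice as app.py
    request_queue, response_queue = mp.Queue(), mp.Queue()
    worker = mp.Process(target=echo_worker, args=(request_queue, response_queue), daemon=True)
    worker.start()

    request_queue.put("hello from the gradio front-end")
    while True:
        piece = response_queue.get()  # mirrors stream_model_output in app.py
        print(piece["text"])
        if piece["end_of_content"]:
            break
```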
config/llama2/7B.json
ADDED
@@ -0,0 +1 @@
+{"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-05, "vocab_size": -1}
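
config/llama2/7B.json carries the LLaMA-2 7B architecture hyperparameters that MetaModel is pointed at via DemoConfig.llama_config. A small sketch of consuming it, using only the standard json module; vocab_size is -1 here because, as in the earlier LLaMA demo code being removed above, the real vocabulary size is filled in from the tokenizer at load time.

```python
# Sketch: read the model hyperparameters shipped in this commit.
# The -1 vocab_size is a placeholder that the loader overrides with the
# tokenizer's n_words; treating it as final would be a bug.
import json

with open("config/llama2/7B.json") as f:
    params = json.load(f)

assert params["vocab_size"] == -1  # filled from the tokenizer later
print(f"{params['n_layers']} layers, {params['n_heads']} heads, dim={params['dim']}")
```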
config/llama2/tokenizer.model
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
+size 499723
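
The three lines above are a Git LFS pointer, not the tokenizer itself; the actual 499,723-byte model is fetched by LFS at checkout. A hedged sketch of inspecting it once the real file is present, assuming the sentencepiece package (which LLaMA-style tokenizer.model files are built with):

```python
# Sketch: inspect the SentencePiece model behind config/llama2/tokenizer.model.
# Requires the real LFS object, not the pointer file shown in this diff.
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="config/llama2/tokenizer.model")
print(sp.vocab_size())  # this is what replaces vocab_size = -1 in 7B.json
print(sp.encode("OneLLM aligns all modalities with language."))
```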
data/__pycache__/conversation_lib.cpython-310.pyc
ADDED
Binary file (9.14 kB).

data/__pycache__/conversation_lib.cpython-39.pyc
ADDED
Binary file (9.15 kB).

data/__pycache__/fintune_dataset.cpython-310.pyc
ADDED
Binary file (14.2 kB).

data/__pycache__/fintune_dataset.cpython-39.pyc
ADDED
Binary file (14.2 kB).

data/__pycache__/imu_utils.cpython-310.pyc
ADDED
Binary file (6.71 kB).

data/__pycache__/imu_utils.cpython-39.pyc
ADDED
Binary file (6.71 kB).

data/__pycache__/video_utils.cpython-310.pyc
ADDED
Binary file (6.53 kB).

data/__pycache__/video_utils.cpython-39.pyc
ADDED
Binary file (6.51 kB).
data/conversation_lib.py
ADDED
@@ -0,0 +1,369 @@
+import dataclasses
+from enum import auto, Enum
+from typing import List, Tuple
+
+
+class SeparatorStyle(Enum):
+    """Different separator style."""
+    SINGLE = auto()
+    TWO = auto()
+    MPT = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+    """A class that keeps all conversation history."""
+    system: str
+    roles: List[str]
+    messages: List[List[str]]
+    offset: int
+    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+    sep: str = "###"
+    sep2: str = None
+    version: str = "Unknown"
+
+    skip_next: bool = False
+
+    def get_prompt(self):
+        if self.sep_style == SeparatorStyle.SINGLE:
+            ret = self.system + '\n\n' + self.sep
+            for role, message in self.messages:
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + ": " + message + '\n' + self.sep
+                else:
+                    ret += role + ":"
+            return ret
+        elif self.sep_style == SeparatorStyle.TWO:
+            seps = [self.sep, self.sep2]
+            ret = self.system + seps[0]
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + ": " + message + seps[i % 2]
+                else:
+                    ret += role + ":"
+            return ret
+        if self.sep_style == SeparatorStyle.MPT:
+            ret = self.system + self.sep
+            for role, message in self.messages:
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + message + self.sep
+                else:
+                    ret += role
+            return ret
+        else:
+            raise ValueError(f"Invalid style: {self.sep_style}")
+
+    def append_message(self, role, message):
+        self.messages.append([role, message])
+
+    def get_images(self, return_pil=False):
+        images = []
+        for i, (role, msg) in enumerate(self.messages[self.offset:]):
+            if i % 2 == 0:
+                if type(msg) is tuple:
+                    import base64
+                    from io import BytesIO
+                    from PIL import Image
+                    msg, image, image_process_mode = msg
+                    if image_process_mode == "Pad":
+                        def expand2square(pil_img, background_color=(122, 116, 104)):
+                            width, height = pil_img.size
+                            if width == height:
+                                return pil_img
+                            elif width > height:
+                                result = Image.new(pil_img.mode, (width, width), background_color)
+                                result.paste(pil_img, (0, (width - height) // 2))
+                                return result
+                            else:
+                                result = Image.new(pil_img.mode, (height, height), background_color)
+                                result.paste(pil_img, ((height - width) // 2, 0))
+                                return result
+
+                        image = expand2square(image)
+                    elif image_process_mode == "Crop":
+                        pass
+                    elif image_process_mode == "Resize":
+                        image = image.resize((224, 224))
+                    else:
+                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
+                    max_hw, min_hw = max(image.size), min(image.size)
+                    aspect_ratio = max_hw / min_hw
+                    max_len, min_len = 800, 400
+                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+                    longest_edge = int(shortest_edge * aspect_ratio)
+                    W, H = image.size
+                    if H > W:
+                        H, W = longest_edge, shortest_edge
+                    else:
+                        H, W = shortest_edge, longest_edge
+                    image = image.resize((W, H))
+                    if return_pil:
+                        images.append(image)
+                    else:
+                        buffered = BytesIO()
+                        image.save(buffered, format="JPEG")
+                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                        images.append(img_b64_str)
+        return images
+
+    def to_gradio_chatbot(self):
+        ret = []
+        for i, (role, msg) in enumerate(self.messages[self.offset:]):
+            if i % 2 == 0:
+                if type(msg) is tuple:
+                    import base64
+                    from io import BytesIO
+                    msg, image, image_process_mode = msg
+                    max_hw, min_hw = max(image.size), min(image.size)
+                    aspect_ratio = max_hw / min_hw
+                    max_len, min_len = 800, 400
+                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+                    longest_edge = int(shortest_edge * aspect_ratio)
+                    W, H = image.size
+                    if H > W:
+                        H, W = longest_edge, shortest_edge
+                    else:
+                        H, W = shortest_edge, longest_edge
+                    image = image.resize((W, H))
+                    # image = image.resize((224, 224))
+                    buffered = BytesIO()
+                    image.save(buffered, format="JPEG")
+                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
+                    msg = msg.replace('<image>', img_str)
+                ret.append([msg, None])
+            else:
+                ret[-1][-1] = msg
+        return ret
+
+    def copy(self):
+        return Conversation(
+            system=self.system,
+            roles=self.roles,
+            messages=[[x, y] for x, y in self.messages],
+            offset=self.offset,
+            sep_style=self.sep_style,
+            sep=self.sep,
+            sep2=self.sep2)
+
+    def dict(self):
+        if len(self.get_images()) > 0:
+            return {
+                "system": self.system,
+                "roles": self.roles,
+                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
+                "offset": self.offset,
+                "sep": self.sep,
+                "sep2": self.sep2,
+            }
+        return {
+            "system": self.system,
+            "roles": self.roles,
+            "messages": self.messages,
+            "offset": self.offset,
+            "sep": self.sep,
+            "sep2": self.sep2,
+        }
+
+
+conv_v1 = Conversation(
+    system="A chat between a curious human and an artificial intelligence assistant. "
+           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+    roles=("Human", "Assistant"),
+    messages=(
+        ("Human", "Give three tips for staying healthy."),
+        ("Assistant",
+         "Sure, here are three tips for staying healthy:\n"
+         "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
+         "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
+         "and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or "
+         "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening "
+         "activities at least two days per week.\n"
+         "2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, "
+         "vegetables, whole grains, lean proteins, and healthy fats can help support "
+         "your overall health. Try to limit your intake of processed and high-sugar foods, "
+         "and aim to drink plenty of water throughout the day.\n"
+         "3. Get enough sleep: Getting enough quality sleep is essential for your physical "
+         "and mental health. Adults should aim for seven to nine hours of sleep per night. "
+         "Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
+         "help improve the quality of your sleep.")
+    ),
+    offset=2,
+    sep_style=SeparatorStyle.SINGLE,
+    sep="###",
+)
+
+conv_v1_2 = Conversation(
+    system="A chat between a curious human and an artificial intelligence assistant. "
+           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+    roles=("Human", "Assistant"),
+    messages=(),
+
+    # (
+    #     ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
+    #     ("Assistant",
+    #      "Renewable energy sources are those that can be replenished naturally in a relatively "
+    #      "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
+    #      "Non-renewable energy sources, on the other hand, are finite and will eventually be "
+    #      "depleted, such as coal, oil, and natural gas. Here are some key differences between "
+    #      "renewable and non-renewable energy sources:\n"
+    #      "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
+    #      "energy sources are finite and will eventually run out.\n"
+    #      "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
+    #      "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
+    #      "and other negative effects.\n"
+    #      "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
+    #      "have lower operational costs than non-renewable sources.\n"
+    #      "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
+    #      "locations than non-renewable sources.\n"
+    #      "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
+    #      "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
+    #      "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
+    #      "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
+    # )
+    offset = 2,
+    sep_style = SeparatorStyle.SINGLE,
+    sep = "###",
+)
+
+conv_vicuna_v1_1 = Conversation(
+    system="A chat between a curious user and an artificial intelligence assistant. "
+           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+    roles=("USER", "ASSISTANT"),
+    version="v1",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.TWO,
+    sep=" ",
+    sep2="</s>",
+)
+
+conv_mpt = Conversation(
+    system="""<|im_start|>system
+- You are a helpful language and vision assistant.
+- You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
+- You should follow the instructions carefully and explain your answers in detail.""",
+    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+    version="mpt",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.MPT,
+    sep="<|im_end|>",
+)
+
+conv_mpt_text = Conversation(
+    system="""<|im_start|>system
+- You are a helpful assistant chatbot trained by MosaicML.
+- You answer questions.
+- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
+- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""",
+    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+    version="mpt",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.MPT,
+    sep="<|im_end|>",
+)
+
+conv_bair_v1 = Conversation(
+    system="BEGINNING OF CONVERSATION:",
+    roles=("USER", "GPT"),
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.TWO,
+    sep=" ",
+    sep2="</s>",
+)
+
+simple_conv = Conversation(
+    system="A chat between a curious human and an artificial intelligence assistant. "
+           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+    roles=("Human", "Assistant"),
+    messages=(
+        ("Human", "Hi!"),
+        ("Assistant", "Hi there! How can I help you today?")
+    ),
+    offset=2,
+    sep_style=SeparatorStyle.SINGLE,
+    sep="###",
+)
+
+simple_conv_multimodal = Conversation(
+    system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
+           "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+           "Follow the instructions carefully and explain your answers in detail.",
+    roles=("Human", "Assistant"),
+    messages=(
+        ("Human", "Hi!"),
+        ("Assistant", "Hi there! How can I help you today?\n")
+    ),
+    offset=2,
+    sep_style=SeparatorStyle.SINGLE,
+    sep="###",
+)
+
+simple_conv_mpt_multimodal = Conversation(
+    system="""<|im_start|>system
+- You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab.
+- You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
+- You should follow the instructions carefully and explain your answers in detail.""",
+    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+    version="mpt",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.MPT,
+    sep="<|im_end|>",
+)
+
+simple_conv_legacy = Conversation(
+    system="You are LLaVA, a large language model trained by UW Madison WAIV Lab."
+           "You are designed to assist human with a variety of tasks using natural language."
+           "Follow the instructions carefully.",
+    roles=("Human", "Assistant"),
+    messages=(
+        ("Human", "Hi!\n\n### Response:"),
+        ("Assistant", "Hi there! How can I help you today?\n")
+    ),
+    offset=2,
+    sep_style=SeparatorStyle.SINGLE,
+    sep="###",
+)
+
+conv_llava_v1 = Conversation(
+    system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
+           "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+           "Follow the instructions carefully and explain your answers in detail.",
+    roles=("USER", "ASSISTANT"),
+    version="v1",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.TWO,
+    sep=" ",
+    sep2="</s>",
+)
+
+default_conversation = conv_v1_2
+conv_templates = {
+    "default": conv_v1_2,
+    "simple": simple_conv,
+    "simple_legacy": simple_conv_legacy,
+    "multimodal": simple_conv_multimodal,
+    "mpt_multimodal": simple_conv_mpt_multimodal,
+    "llava_v1": conv_llava_v1,
+
+    # fastchat
+    "v1": conv_v1_2,
+    "bair_v1": conv_bair_v1,
+    "vicuna_v1_1": conv_vicuna_v1_1,
+    "mpt": conv_mpt,
+    "mpt_text": conv_mpt_text,
+}
+
+if __name__ == "__main__":
+    print(default_conversation.get_prompt())
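
conversation_lib.py is what app.py leans on to turn the chat history into a single prompt string: it copies conv_templates["v1"], appends alternating Human/Assistant turns, and calls get_prompt(). A minimal usage sketch, assuming the repo root is on PYTHONPATH:

```python
# Sketch: build a prompt the same way model_worker does in app.py.
from data.conversation_lib import conv_templates

conv = conv_templates["v1"].copy()
conv.append_message(conv.roles[0], "What is in this image?")  # Human turn
conv.append_message(conv.roles[1], None)                      # Assistant turn to be generated
print(conv.get_prompt())
# -> system text, "### Human: ..." turns, ending with the "Assistant:" tag for the model to complete
```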
data/fintune_dataset.py
ADDED
@@ -0,0 +1,449 @@
+import warnings
+
+import torch
+import yaml
+from torch.utils.data import Dataset
+from PIL import Image
+import json
+from model.tokenizer import Tokenizer
+import os
+import torchvision.transforms as transforms
+import random
+import torchvision.transforms.functional as F
+import torchaudio
+from . import conversation_lib
+
+import numpy as np
+from . import video_utils
+from .imu_utils import get_imu_frames
+
+
+IGNORE_INDEX = -100
+
+DEFAULT_IMAGE_TOKEN = "<image>"
+try:
+    from torchvision.transforms import InterpolationMode
+
+    BICUBIC = InterpolationMode.BICUBIC
+except ImportError:
+    BICUBIC = Image.BICUBIC
+
+T_random_resized_crop = transforms.Compose([
+    transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333), interpolation=BICUBIC,
+                                 antialias=None),  # 3 is bicubic
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
+
+
+# image transform
+transform_img_train = transforms.Compose([
+    transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(
+        0.75, 1.3333), interpolation=3, antialias=None),  # 3 is bicubic
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
+
+
+class PairRandomResizedCrop(transforms.RandomResizedCrop):
+    def forward(self, imgs):
+        i, j, h, w = self.get_params(imgs[0], self.scale, self.ratio)
+        return [F.resized_crop(img, i, j, h, w, self.size, self.interpolation, antialias=self.antialias) for img in imgs]
+
+
+class PairToTensor(transforms.ToTensor):
+    def __call__(self, pics):
+        return [F.to_tensor(pic) for pic in pics]
+
+
+class PairNormalize(transforms.Normalize):
+    def forward(self, tensors):
+        return [F.normalize(tensor, self.mean, self.std, self.inplace) for tensor in tensors]
+
+
+transform_pairimg_train = transforms.Compose([
+    PairRandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(
+        0.75, 1.3333), interpolation=3, antialias=None),  # 3 is bicubic
+    PairToTensor(),
+    PairNormalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
+
+
+def pc_norm(pc):
+    """ pc: NxC, return NxC """
+    xyz = pc[:, :3]
+    other_feature = pc[:, 3:]
+
+    centroid = torch.mean(xyz, dim=0)
+    xyz = xyz - centroid
+    m = torch.max(torch.sqrt(torch.sum(xyz ** 2, dim=1)))
+    xyz = xyz / m
+
+    pc = torch.cat((xyz, other_feature), dim=1)
+    return pc
+
+
+def make_audio_features(wav_name, mel_bins=128, target_length=1024, aug=False):
+    waveform, sr = torchaudio.load(wav_name)
+    # assert sr == 16000, 'input audio sampling rate must be 16kHz'
+    if sr != 16000:
+        trans = torchaudio.transforms.Resample(sr, 16000)
+        waveform = trans(waveform)
+
+    waveform = waveform - waveform.mean()
+
+    fbank = torchaudio.compliance.kaldi.fbank(
+        waveform, htk_compat=True, sample_frequency=16000, use_energy=False,
+        window_type='hanning', num_mel_bins=mel_bins, dither=0.0, frame_shift=10)
+
+    n_frames = fbank.shape[0]
+
+    p = target_length - n_frames
+    if p > 0:
+        m = torch.nn.ZeroPad2d((0, 0, 0, p))
+        fbank = m(fbank)
+    elif p < 0:
+        fbank = fbank[0:target_length, :]
+
+    if aug:
+        freqm = torchaudio.transforms.FrequencyMasking(48)
+        timem = torchaudio.transforms.TimeMasking(192)
+        fbank = torch.transpose(fbank, 0, 1)
+        fbank = fbank.unsqueeze(0)
+        fbank = freqm(fbank)
+        fbank = timem(fbank)
+        fbank = fbank.squeeze(0)
+        fbank = torch.transpose(fbank, 0, 1)
+
+    fbank = (fbank - (-4.2677393)) / (4.5689974 * 2)
+    return fbank
+
+
+class ConversationGenerator:
+    def __init__(self, tokenizer):
+        self.tokenizer = tokenizer
+        self.header = f"{conversation_lib.default_conversation.system}\n\n"
+        self._probe_tokenizer_style()
+
+    def _probe_tokenizer_style(self):
+        """
+        Given a sentence, e.g. "My darling", some tokenizers will make the space a seperate token,
+        while some others will merge the space into the next word, forming a token representing " darling".
+        Knowing which style the tokenizer takes is necessary for correct ground-truth label masking.
+
+        """
+        probe = "Probe am I"
+        sentence1 = self.tokenizer.encode(conversation_lib.default_conversation.roles[1] + ": " + probe,
+                                          bos=False, eos=False)
+        sentence2 = self.tokenizer.encode(probe,
+                                          bos=False, eos=False)
+        if sentence1[-len(sentence2):] == sentence2:
+            self.space_before_to_predict = False
+        else:
+            sentence3 = self.tokenizer.encode(" " + probe,
+                                              bos=False, eos=False)
+            assert sentence1[-len(sentence3):] == sentence3
+            self.space_before_to_predict = True
+
+    def add_speaker_and_signal(self, source, get_conversation=True):
+        """Add speaker and start/end signal on each round."""
+        BEGIN_SIGNAL = "### "
+        END_SIGNAL = "\n"
+        conversation = self.header
+
+        to_predict_list = []
+
+        for sentence in source:
+            from_str = sentence["from"]
+            if from_str.lower() in ["human"]:
+                from_str = conversation_lib.default_conversation.roles[0]
+            elif from_str.lower() in ["gpt", "assistant"]:
+                from_str = conversation_lib.default_conversation.roles[1]
+            else:
+                raise ValueError(f"unknown dialog role: {from_str.lower()}")
+
+            value = sentence["value"]
+            if DEFAULT_IMAGE_TOKEN in value:
+                value = value.replace(DEFAULT_IMAGE_TOKEN, '').strip()
+
+            sentence_value = BEGIN_SIGNAL + from_str + ": " + value + END_SIGNAL
+
+            if from_str == conversation_lib.default_conversation.roles[1]:
+                to_predict_value = value + END_SIGNAL + "###"
+                if self.space_before_to_predict:
+                    to_predict_value = " " + to_predict_value
+                to_predict_list.append(to_predict_value)
+
+            if get_conversation:
+                conversation = conversation + sentence_value
+
+        conversation = conversation + BEGIN_SIGNAL
+        return conversation, to_predict_list
+
+
+DATASETS = dict(
+    image=[
+        dict(path="datasets/InstructionTuning/image/llava_v1_5_mix665k_image.json", type='image'),
+        dict(path='datasets/InstructionTuning/image/cococap_train.json', type='image'),
+        dict(path="datasets/InstructionTuning/image/llava_v1_5_mix665k_text.json", type='text'),
+    ],
+    audio=[
+        dict(path="datasets/InstructionTuning/audio/audiocap_train.json", type='audio'),
+        dict(path="datasets/InstructionTuning/audio/audiocap_val.json", type='audio'),
+        dict(path="datasets/InstructionTuning/audio/audio_conversation.json", type='audio'),
+    ],
+    video=[
+        dict(path="datasets/InstructionTuning/video/msrvtt_cap_trainval.json", type='video'),
+        dict(path="datasets/InstructionTuning/video/msrvtt_cap_test.json", type='video'),
+        dict(path="datasets/InstructionTuning/video/msrvtt_vqa_train.json", type='video'),
+        dict(path="datasets/InstructionTuning/video/msrvtt_vqa_val.json", type='video'),
+        dict(path="datasets/InstructionTuning/video/msrvtt_vqa_test.json", type='video'),
+        dict(path="datasets/InstructionTuning/video/video_complex_reasoning_10k.json", type='video'),
+        dict(path="datasets/InstructionTuning/video/video_conversation_10k.json", type='video'),
+        dict(path="datasets/InstructionTuning/video/video_detail_10k.json", type='video'),
+    ],
+    point=[
+        dict(path="datasets/InstructionTuning/point/pointllm_70k_formated.json", type='point'),
+    ],
+    rgbd=[
+        dict(path="datasets/InstructionTuning/depth_normal/llava_instruct_50k_depth.json", type='rgbd'),
+    ],
+    rgbn=[
+        dict(path="datasets/InstructionTuning/depth_normal/llava_instruct_50k_normal.json", type='rgbn'),
+    ],
+    imu=[
+        dict(path="datasets/InstructionTuning/imu/imu_fixed_50k.json", type='imu'),
+    ],
+    fmri=[
+        dict(path="datasets/InstructionTuning/fmri/fmri_fixed.json", type='fmri'),
+    ],
+)
+IMU_PATH = "/mnt/petrelfs/share_data/hanjiaming/ego4d/v2/processed_imu/"
+
+
+class FinetuneDialogDataset(Dataset):
+    def __init__(self, dataset=['image'], transform=T_random_resized_crop, max_words=2048, image_words=30, tokenizer_path=None):
+        if isinstance(dataset, str):
+            dataset = [dataset]
+
+        self.dataset = dataset
+
+        group_ann = {}
+        for d in dataset:
+            for meta in DATASETS[d]:
+                meta_path, meta_type = meta['path'], meta['type']
+                meta_ext = os.path.splitext(meta_path)[-1]
+                if meta_ext == ".json":
+                    with open(meta_path) as f:
+                        meta_l = json.load(f)
+                        # add data_type
+                        # this is a temp solution
+                        new_meta_l = []
+                        for l in meta_l:
+                            l['data_type'] = meta_type
+                            new_meta_l.append(l)
+                        meta_l = new_meta_l
+                elif meta_ext == ".jsonl":
+                    meta_l = []
+                    with open(meta_path) as f:
+                        for i, line in enumerate(f):
+                            try:
+                                meta_l.append(json.loads(line))
+                            except json.decoder.JSONDecodeError as e:
+                                print(
+                                    f"Error decoding the following jsonl line ({i}):\n{line.rstrip()}", force=True)
+                                raise e
+                else:
+                    raise NotImplementedError(
+                        f"Unknown meta file extension: \"{meta_ext}\". "
+                        f"Currently, .json, .jsonl are supported. "
+                        "If you are using a supported format, please set the file extension so that the proper parsing "
+                        "routine can be called."
+                    )
+                if meta_type not in group_ann:
+                    group_ann[meta_type] = []
+                print(f"{meta_path}, type {meta_type}: len {len(meta_l)}")
+                group_ann[meta_type] += meta_l
+
+        # sort group_ann for higher efficiency (items in one global batch with similar length)
+        for meta_type, meta_l in group_ann.items():
+            meta_l.sort(key=lambda data_item: sum(
+                [len(_['value']) for _ in data_item['conversations']]))
+
+        self.group_ann = group_ann
|
271 |
+
self.ann = sum(list(self.group_ann.values()), start=[])
|
272 |
+
|
273 |
+
self.group_indices = {}
|
274 |
+
start_pos = 0
|
275 |
+
for meta_type, meta_l in self.group_ann.items():
|
276 |
+
self.group_indices[meta_type] = list(
|
277 |
+
range(start_pos, start_pos + len(meta_l)))
|
278 |
+
start_pos = start_pos + len(meta_l)
|
279 |
+
|
280 |
+
print(f"total length: {len(self)}")
|
281 |
+
self.transform = transform
|
282 |
+
print(f"transform:\n{self.transform}")
|
283 |
+
self.max_words = max_words
|
284 |
+
self.image_words = image_words
|
285 |
+
self.tokenizer = Tokenizer(model_path=tokenizer_path)
|
286 |
+
self.conversation_generator = ConversationGenerator(self.tokenizer)
|
287 |
+
|
288 |
+
self.load_funcs = dict(
|
289 |
+
image=self.load_image,
|
290 |
+
audio=self.load_audio,
|
291 |
+
video=self.load_video,
|
292 |
+
point=self.load_point,
|
293 |
+
rgbd=self.load_rgbx,
|
294 |
+
rgbn=self.load_rgbx,
|
295 |
+
imu=self.load_imu,
|
296 |
+
fmri=self.load_fmri
|
297 |
+
)
|
298 |
+
|
299 |
+
def __len__(self):
|
300 |
+
return len(self.ann)
|
301 |
+
|
302 |
+
def load_image(self, data):
|
303 |
+
filename = data['image']
|
304 |
+
image = Image.open(filename).convert('RGB')
|
305 |
+
image = self.transform(image)
|
306 |
+
return image
|
307 |
+
|
308 |
+
def load_audio(self, data):
|
309 |
+
audio_path = data['image']
|
310 |
+
fbank = make_audio_features(audio_path, mel_bins=128)
|
311 |
+
fbank = fbank.transpose(0, 1)[None] # [1, 128, 1024]
|
312 |
+
return fbank
|
313 |
+
|
314 |
+
def load_video(self, data):
|
315 |
+
video_path = data['image']
|
316 |
+
video_feats = video_utils.load_and_transform_video_data(
|
317 |
+
video_path, video_path, clip_duration=1, clips_per_video=5)
|
318 |
+
return video_feats[:, :, 0]
|
319 |
+
|
320 |
+
def load_point(self, data):
|
321 |
+
point_path = data['image']
|
322 |
+
point_feat = torch.load(point_path, map_location='cpu')
|
323 |
+
point_feat = point_feat.transpose(0, 1)
|
324 |
+
return point_feat
|
325 |
+
|
326 |
+
def load_rgbx(self, data):
|
327 |
+
image_path = data['image']
|
328 |
+
x_image_path = data['depth_image'] if 'depth_image' in data else data['normal_image']
|
329 |
+
image = Image.open(image_path).convert('RGB')
|
330 |
+
x_image = Image.open(x_image_path).convert('RGB')
|
331 |
+
x_image = x_image.resize(image.size[-2:])
|
332 |
+
|
333 |
+
image, x_image = transform_pairimg_train([image, x_image])
|
334 |
+
# [2, 3, H, W]
|
335 |
+
image = torch.stack([image, x_image], dim=0)
|
336 |
+
return image
|
337 |
+
|
338 |
+
def load_fmri(self, data):
|
339 |
+
fmri_path = data['image']
|
340 |
+
data = np.load(fmri_path)
|
341 |
+
data = data.mean(axis=0)
|
342 |
+
data = torch.tensor(data[None])
|
343 |
+
return data
|
344 |
+
|
345 |
+
def load_imu(self, data_dict):
|
346 |
+
uid = data_dict["video_uid"]
|
347 |
+
w_s = data_dict["window_start"]
|
348 |
+
w_e = data_dict["window_end"]
|
349 |
+
|
350 |
+
imu_data = get_imu_frames(
|
351 |
+
IMU_PATH, uid,
|
352 |
+
video_start_sec=w_s,
|
353 |
+
video_end_sec=w_e,
|
354 |
+
)
|
355 |
+
if imu_data is None:
|
356 |
+
raise ValueError
|
357 |
+
return imu_data['signal']
|
358 |
+
|
359 |
+
def __getitem__(self, index, expect_type=None):
|
360 |
+
if expect_type is None:
|
361 |
+
data_item = self.ann[index]
|
362 |
+
else:
|
363 |
+
# in case we want to get data from a specific data_type
|
364 |
+
data_item = self.group_ann[expect_type][index]
|
365 |
+
|
366 |
+
data_type = data_item['data_type']
|
367 |
+
if data_type != 'text':
|
368 |
+
if data_type in self.load_funcs:
|
369 |
+
try:
|
370 |
+
image = self.load_funcs[data_type](data_item)
|
371 |
+
if image is None:
|
372 |
+
raise ValueError('Data is None')
|
373 |
+
except Exception:
|
374 |
+
print('Error', data_item)
|
375 |
+
rand_idx = random.randint(
|
376 |
+
0, len(self.group_ann[data_type]) - 1)
|
377 |
+
return self.__getitem__(rand_idx, expect_type=data_type)
|
378 |
+
else:
|
379 |
+
raise ValueError(f'Does not support {data_type}')
|
380 |
+
else:
|
381 |
+
image = None
|
382 |
+
# warnings.warn("pure black image for examples without image")
|
383 |
+
# image = torch.zeros(3, 224, 224)
|
384 |
+
|
385 |
+
source = data_item["conversations"]
|
386 |
+
conversation, to_predict_values = self.conversation_generator.add_speaker_and_signal(
|
387 |
+
source)
|
388 |
+
if len(to_predict_values) == 0:
|
389 |
+
warnings.warn(
|
390 |
+
f"see dialog data with nothing to predict, data: {data_item}")
|
391 |
+
return self[index-1]
|
392 |
+
|
393 |
+
tokenzed_conversation = self.tokenizer.encode(
|
394 |
+
conversation, bos=True, eos=True)
|
395 |
+
labels = [IGNORE_INDEX for _ in tokenzed_conversation]
|
396 |
+
|
397 |
+
check_pos = 0
|
398 |
+
for value in to_predict_values:
|
399 |
+
tokenized_value = self.tokenizer.encode(
|
400 |
+
value, bos=False, eos=False)
|
401 |
+
value_pos = find_sublist(
|
402 |
+
tokenzed_conversation[check_pos:], tokenized_value) + check_pos
|
403 |
+
if value_pos == -1:
|
404 |
+
print(
|
405 |
+
"a sentence mismatches the corresponding piece in the conversation")
|
406 |
+
return self[index-1]
|
407 |
+
labels[value_pos:value_pos+len(tokenized_value)] = tokenized_value
|
408 |
+
assert labels[value_pos:value_pos+len(
|
409 |
+
tokenized_value)] == tokenzed_conversation[value_pos:value_pos+len(tokenized_value)]
|
410 |
+
check_pos = value_pos+len(tokenized_value)
|
411 |
+
|
412 |
+
input2 = torch.tensor(tokenzed_conversation, dtype=torch.int64)
|
413 |
+
labels = torch.tensor(labels, dtype=torch.int64)
|
414 |
+
|
415 |
+
if image is not None:
|
416 |
+
max_words = self.max_words - self.image_words
|
417 |
+
else:
|
418 |
+
max_words = self.max_words
|
419 |
+
padding = max_words - input2.shape[0]
|
420 |
+
if padding > 0:
|
421 |
+
input2 = torch.cat(
|
422 |
+
(input2, torch.zeros(padding, dtype=torch.int64) - 1))
|
423 |
+
labels = torch.cat(
|
424 |
+
(labels, torch.zeros(padding, dtype=torch.int64) - 1))
|
425 |
+
elif padding < 0:
|
426 |
+
input2 = input2[:max_words]
|
427 |
+
labels = labels[:max_words]
|
428 |
+
|
429 |
+
input2_mask = input2.ge(0)
|
430 |
+
label_mask = labels.ge(0)
|
431 |
+
input2[~input2_mask] = 0
|
432 |
+
labels[~label_mask] = 0
|
433 |
+
input2_mask = input2_mask.float()
|
434 |
+
label_mask = label_mask.float()
|
435 |
+
if image is None:
|
436 |
+
return input2, labels, data_item['data_type']
|
437 |
+
else:
|
438 |
+
return input2, labels, image, data_item['data_type']
|
439 |
+
|
440 |
+
def groups(self):
|
441 |
+
return list(self.group_indices.values())
|
442 |
+
|
443 |
+
|
444 |
+
def find_sublist(a: list, b: list):
|
445 |
+
len_a, len_b = len(a), len(b)
|
446 |
+
for i in range(len_a - len_b + 1):
|
447 |
+
if a[i:i+len_b] == b:
|
448 |
+
return i
|
449 |
+
return -1
|
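The label-masking logic in `FinetuneDialogDataset.__getitem__` is easy to lose in the padding and tensor bookkeeping. The following self-contained sketch reproduces just the core idea with a toy whitespace tokenizer (a stand-in for the SentencePiece `Tokenizer`, purely for illustration): the whole conversation is tokenized once, and only the assistant replies, located via `find_sublist`, are copied into the label sequence; everything else stays at the ignore index.

```python
# Sketch of assistant-only label masking (toy whitespace tokenizer,
# not the real SentencePiece model used by the dataset).
IGNORE_INDEX = -100

def find_sublist(a, b):
    """Return the start index of sub-list b inside a, or -1 if absent."""
    for i in range(len(a) - len(b) + 1):
        if a[i:i + len(b)] == b:
            return i
    return -1

def toy_encode(text):
    # word-level stand-in for tokenizer.encode(text, bos=False, eos=False)
    return text.split()

conversation = "### Human: What is in the image? ### Assistant: A small dog. ###"
to_predict = ["A small dog. ###"]          # assistant reply + end signal

tokens = toy_encode(conversation)
labels = [IGNORE_INDEX] * len(tokens)

check_pos = 0
for value in to_predict:
    piece = toy_encode(value)
    pos = find_sublist(tokens[check_pos:], piece) + check_pos
    labels[pos:pos + len(piece)] = piece   # supervise only the reply tokens
    check_pos = pos + len(piece)

print(list(zip(tokens, labels)))
```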
data/imu_utils.py
ADDED
@@ -0,0 +1,257 @@
1 |
+
import string
|
2 |
+
import numpy as np
|
3 |
+
import matplotlib.animation as animation
|
4 |
+
from matplotlib import pyplot as plt
|
5 |
+
import json
|
6 |
+
from collections import defaultdict
|
7 |
+
from bisect import bisect_left
|
8 |
+
import os
|
9 |
+
import torch
|
10 |
+
import torchaudio
|
11 |
+
torchaudio.set_audio_backend("sox_io")
|
12 |
+
|
13 |
+
|
14 |
+
def load_json(json_path: str):
|
15 |
+
"""
|
16 |
+
Load a json file
|
17 |
+
"""
|
18 |
+
with open(json_path, "r", encoding="utf-8") as f_name:
|
19 |
+
data = json.load(f_name)
|
20 |
+
return data
|
21 |
+
|
22 |
+
|
23 |
+
def check_window_signal(info_t, w_s, w_e):
|
24 |
+
length = w_e - w_s
|
25 |
+
frame_offset = int(w_s * info_t.sample_rate)
|
26 |
+
num_frames = int(length * info_t.sample_rate)
|
27 |
+
if frame_offset + num_frames > int(info_t.num_frames):
|
28 |
+
return False
|
29 |
+
else:
|
30 |
+
return True
|
31 |
+
|
32 |
+
|
33 |
+
def index_narrations(ann_path):
|
34 |
+
narration_raw = load_json(ann_path)
|
35 |
+
|
36 |
+
narration_dict = defaultdict(list)
|
37 |
+
summary_dict = defaultdict(list)
|
38 |
+
avg_len = []
|
39 |
+
for v_id, narr in narration_raw.items():
|
40 |
+
narr_list = []
|
41 |
+
summ_list = []
|
42 |
+
if "narration_pass_1" in narr:
|
43 |
+
narr_list += narr["narration_pass_1"]["narrations"]
|
44 |
+
summ_list += narr["narration_pass_1"]["summaries"]
|
45 |
+
if "narration_pass_2" in narr:
|
46 |
+
narr_list += narr["narration_pass_2"]["narrations"]
|
47 |
+
summ_list += narr["narration_pass_2"]["summaries"]
|
48 |
+
|
49 |
+
if len(narr_list) > 0:
|
50 |
+
narration_dict[v_id] = [
|
51 |
+
(
|
52 |
+
float(n_t["timestamp_sec"]),
|
53 |
+
n_t["narration_text"],
|
54 |
+
n_t["annotation_uid"],
|
55 |
+
n_t["timestamp_frame"],
|
56 |
+
)
|
57 |
+
for n_t in narr_list
|
58 |
+
]
|
59 |
+
avg_len.append(len(narration_dict[v_id]))
|
60 |
+
else:
|
61 |
+
narration_dict[v_id] = []
|
62 |
+
if len(summ_list) > 0:
|
63 |
+
summary_dict[v_id] = [
|
64 |
+
(
|
65 |
+
float(s_t["start_sec"]),
|
66 |
+
float(s_t["end_sec"]),
|
67 |
+
s_t["summary_text"],
|
68 |
+
)
|
69 |
+
for s_t in summ_list
|
70 |
+
]
|
71 |
+
else:
|
72 |
+
summary_dict[v_id] = []
|
73 |
+
# print(f"Number of Videos with narration {len(narration_dict)}")
|
74 |
+
# print(f"Avg. narration length {np.mean(avg_len)}")
|
75 |
+
# print(f"Number of Videos with summaries {len(summary_dict)}")
|
76 |
+
return narration_dict, summary_dict
|
77 |
+
|
78 |
+
|
79 |
+
def get_signal_info(signal_fn: str):
|
80 |
+
return torchaudio.info(signal_fn)
|
81 |
+
|
82 |
+
|
83 |
+
def get_signal_frames(signal_fn: str, video_start_sec: float, video_end_sec: float):
|
84 |
+
"""
|
85 |
+
Given a signal track, return the frames between video_start_sec and video_end_sec
|
86 |
+
"""
|
87 |
+
info_t = get_signal_info(signal_fn)
|
88 |
+
|
89 |
+
length = video_end_sec - video_start_sec
|
90 |
+
aframes, _ = torchaudio.load(
|
91 |
+
signal_fn,
|
92 |
+
normalize=True,
|
93 |
+
frame_offset=int(video_start_sec * info_t.sample_rate),
|
94 |
+
num_frames=int(length * info_t.sample_rate),
|
95 |
+
)
|
96 |
+
return {"signal": aframes, "meta": info_t}
|
97 |
+
|
98 |
+
|
99 |
+
def tosec(value):
|
100 |
+
return value / 1000
|
101 |
+
|
102 |
+
|
103 |
+
def toms(value):
|
104 |
+
return value * 1000
|
105 |
+
|
106 |
+
|
107 |
+
def delta(first_num: float, second_num: float):
|
108 |
+
"""Compute the absolute value of the difference of two numbers"""
|
109 |
+
return abs(first_num - second_num)
|
110 |
+
|
111 |
+
|
112 |
+
def padIMU(signal, duration_sec):
|
113 |
+
"""
|
114 |
+
Pad the signal if necessary
|
115 |
+
"""
|
116 |
+
expected_elements = round(duration_sec) * 200
|
117 |
+
|
118 |
+
if signal.shape[0] > expected_elements:
|
119 |
+
signal = signal[:expected_elements, :]
|
120 |
+
elif signal.shape[0] < expected_elements:
|
121 |
+
padding = expected_elements - signal.shape[0]
|
122 |
+
padded_zeros = np.zeros((padding, 6))
|
123 |
+
signal = np.concatenate([signal, padded_zeros], 0)
|
124 |
+
# signal = signal[:expected_elements, :]
|
125 |
+
return signal
|
126 |
+
|
127 |
+
|
128 |
+
def resample(
|
129 |
+
signals: np.ndarray,
|
130 |
+
timestamps: np.ndarray,
|
131 |
+
original_sample_rate: int,
|
132 |
+
resample_rate: int,
|
133 |
+
):
|
134 |
+
"""
|
135 |
+
Resamples data to new sample rate
|
136 |
+
"""
|
137 |
+
signals = torch.as_tensor(signals)
|
138 |
+
timestamps = torch.from_numpy(timestamps).unsqueeze(-1)
|
139 |
+
signals = torchaudio.functional.resample(
|
140 |
+
waveform=signals.data.T,
|
141 |
+
orig_freq=original_sample_rate,
|
142 |
+
new_freq=resample_rate,
|
143 |
+
).T.numpy()
|
144 |
+
|
145 |
+
nsamples = len(signals)
|
146 |
+
|
147 |
+
period = 1 / resample_rate
|
148 |
+
|
149 |
+
# timestamps are expected to be shape (N, 1)
|
150 |
+
initital_seconds = timestamps[0] / 1e3
|
151 |
+
|
152 |
+
ntimes = (torch.arange(nsamples) * period).view(-1, 1) + initital_seconds
|
153 |
+
|
154 |
+
timestamps = (ntimes * 1e3).squeeze().numpy()
|
155 |
+
return signals, timestamps
|
156 |
+
|
157 |
+
|
158 |
+
def resampleIMU(signal, timestamps):
|
159 |
+
sampling_rate = int(1000 * (1 / (np.mean(np.diff(timestamps)))))
|
160 |
+
# resample all to 200hz
|
161 |
+
if sampling_rate != 200:
|
162 |
+
signal, timestamps = resample(signal, timestamps, sampling_rate, 200)
|
163 |
+
return signal, timestamps
|
164 |
+
|
165 |
+
|
166 |
+
def get_imu_frames(
|
167 |
+
imu_path,
|
168 |
+
uid: str,
|
169 |
+
video_start_sec: float,
|
170 |
+
video_end_sec: float,
|
171 |
+
):
|
172 |
+
"""
|
173 |
+
Given an IMU signal, return the frames between video_start_sec and video_end_sec
|
174 |
+
"""
|
175 |
+
signal = np.load(os.path.join(imu_path, f"{uid}.npy"))
|
176 |
+
signal = signal.transpose()
|
177 |
+
timestamps = np.load(os.path.join(imu_path, f"{uid}_timestamps.npy"))
|
178 |
+
|
179 |
+
if toms(video_start_sec) > timestamps[-1] or toms(video_end_sec) > timestamps[-1]:
|
180 |
+
return None
|
181 |
+
|
182 |
+
start_id = bisect_left(timestamps, toms(video_start_sec))
|
183 |
+
end_id = bisect_left(timestamps, toms(video_end_sec))
|
184 |
+
|
185 |
+
# make sure the retrieved window endpoints match the requested times within a 4-second margin
|
186 |
+
if (
|
187 |
+
delta(video_start_sec, tosec(timestamps[start_id])) > 4
|
188 |
+
or delta(video_end_sec, tosec(timestamps[end_id])) > 4
|
189 |
+
):
|
190 |
+
return None
|
191 |
+
|
192 |
+
# get the window
|
193 |
+
if start_id == end_id:
|
194 |
+
start_id -= 1
|
195 |
+
end_id += 1
|
196 |
+
signal, timestamps = signal[start_id:end_id], timestamps[start_id:end_id]
|
197 |
+
|
198 |
+
if len(signal) < 10 or len(timestamps) < 10:
|
199 |
+
return None
|
200 |
+
# resample the signal at 200hz if necessary
|
201 |
+
signal, timestamps = resampleIMU(signal, timestamps)
|
202 |
+
|
203 |
+
# pad the signal if necessary
|
204 |
+
signal = padIMU(signal, video_end_sec - video_start_sec)
|
205 |
+
|
206 |
+
sample_dict = {
|
207 |
+
"timestamp": timestamps,
|
208 |
+
"signal": torch.tensor(signal.T),
|
209 |
+
"sampling_rate": 200,
|
210 |
+
}
|
211 |
+
|
212 |
+
return sample_dict
|
213 |
+
|
214 |
+
|
215 |
+
def display_animation(frames, title, save_path_gif):
|
216 |
+
fig, ax = plt.subplots()
|
217 |
+
frames = [[ax.imshow(frames[i])] for i in range(len(frames))]
|
218 |
+
plt.title(title)
|
219 |
+
ani = animation.ArtistAnimation(fig, frames)
|
220 |
+
ani.save(save_path_gif, writer="imagemagick")
|
221 |
+
plt.close()
|
222 |
+
|
223 |
+
|
224 |
+
def display_animation_imu(frames, imu, title, save_path_gif):
|
225 |
+
fig, (ax1, ax2, ax3) = plt.subplots(3, 1)
|
226 |
+
ax1.set_title(title)
|
227 |
+
ax2.set_title("Acc.")
|
228 |
+
ax3.set_title("Gyro.")
|
229 |
+
frames = [[ax1.imshow(frames[i])] for i in range(len(frames))]
|
230 |
+
ani = animation.ArtistAnimation(fig, frames)
|
231 |
+
|
232 |
+
ax2.plot(imu[0].cpu().numpy(), color="red")
|
233 |
+
ax2.plot(imu[1].cpu().numpy(), color="blue")
|
234 |
+
ax2.plot(imu[2].cpu().numpy(), color="green")
|
235 |
+
ax3.plot(imu[3].cpu().numpy(), color="red")
|
236 |
+
ax3.plot(imu[4].cpu().numpy(), color="blue")
|
237 |
+
ax3.plot(imu[5].cpu().numpy(), color="green")
|
238 |
+
plt.tight_layout()
|
239 |
+
ani.save(save_path_gif, writer="imagemagick")
|
240 |
+
plt.close()
|
241 |
+
|
242 |
+
|
243 |
+
def filter_narration(narration_text: str) -> bool:
|
244 |
+
if "#c" in narration_text.lower():
|
245 |
+
return True
|
246 |
+
return False
|
247 |
+
|
248 |
+
|
249 |
+
def clean_narration_text(narration_text: str) -> str:
|
250 |
+
return (
|
251 |
+
narration_text.replace("#C C ", "")
|
252 |
+
.replace("#C", "")
|
253 |
+
.replace("#unsure", "something")
|
254 |
+
.strip()
|
255 |
+
.strip(string.punctuation)
|
256 |
+
.lower()[:128]
|
257 |
+
)
|
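For reference, the end-to-end effect of the IMU loading path above is a fixed-rate, fixed-length window: the signal is resampled to 200 Hz and then trimmed or zero-padded to `duration * 200` samples. A numpy-only sketch of the padding step on synthetic data (the 200 Hz rate and the 6-axis layout come from the code above; the random window is just for illustration):

```python
import numpy as np

def pad_imu(signal: np.ndarray, duration_sec: float) -> np.ndarray:
    """Trim or zero-pad a (T, 6) IMU window to duration_sec * 200 samples,
    mirroring padIMU in data/imu_utils.py."""
    expected = round(duration_sec) * 200          # 200 Hz target rate
    if signal.shape[0] > expected:
        return signal[:expected, :]
    if signal.shape[0] < expected:
        pad = np.zeros((expected - signal.shape[0], 6))
        return np.concatenate([signal, pad], axis=0)
    return signal

window = np.random.randn(730, 6)                  # ~3.65 s of 200 Hz IMU data
fixed = pad_imu(window, duration_sec=4.0)
print(fixed.shape)                                # (800, 6)
```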
data/video_utils.py
ADDED
@@ -0,0 +1,204 @@
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from pytorchvideo import transforms as pv_transforms
|
5 |
+
from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
|
6 |
+
from pytorchvideo.data.encoded_video import EncodedVideo
|
7 |
+
from pytorchvideo.data.encoded_video_decord import EncodedVideoDecord
|
8 |
+
from torchvision import transforms
|
9 |
+
from torchvision.transforms._transforms_video import NormalizeVideo
|
10 |
+
|
11 |
+
|
12 |
+
def get_clip_timepoints(clip_sampler, duration):
|
13 |
+
# Read out all clips in this video
|
14 |
+
all_clips_timepoints = []
|
15 |
+
is_last_clip = False
|
16 |
+
end = 0.0
|
17 |
+
while not is_last_clip:
|
18 |
+
start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
|
19 |
+
all_clips_timepoints.append((start, end))
|
20 |
+
return all_clips_timepoints
|
21 |
+
|
22 |
+
|
23 |
+
|
24 |
+
def crop_boxes(boxes, x_offset, y_offset):
|
25 |
+
"""
|
26 |
+
Perform crop on the bounding boxes given the offsets.
|
27 |
+
Args:
|
28 |
+
boxes (ndarray or None): bounding boxes to perform crop. The dimension
|
29 |
+
is `num boxes` x 4.
|
30 |
+
x_offset (int): cropping offset in the x axis.
|
31 |
+
y_offset (int): cropping offset in the y axis.
|
32 |
+
Returns:
|
33 |
+
cropped_boxes (ndarray or None): the cropped boxes with dimension of
|
34 |
+
`num boxes` x 4.
|
35 |
+
"""
|
36 |
+
cropped_boxes = boxes.copy()
|
37 |
+
cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
|
38 |
+
cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
|
39 |
+
|
40 |
+
return cropped_boxes
|
41 |
+
|
42 |
+
|
43 |
+
def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
|
44 |
+
"""
|
45 |
+
Perform uniform spatial sampling on the images and corresponding boxes.
|
46 |
+
Args:
|
47 |
+
images (tensor): images to perform uniform crop. The dimension is
|
48 |
+
`num frames` x `channel` x `height` x `width`.
|
49 |
+
size (int): size of height and width to crop the images.
|
50 |
+
spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
|
51 |
+
is larger than height. Or 0, 1, or 2 for top, center, and bottom
|
52 |
+
crop if height is larger than width.
|
53 |
+
boxes (ndarray or None): optional. Corresponding boxes to images.
|
54 |
+
Dimension is `num boxes` x 4.
|
55 |
+
scale_size (int): optional. If not None, resize the images to scale_size before
|
56 |
+
performing any crop.
|
57 |
+
Returns:
|
58 |
+
cropped (tensor): images with dimension of
|
59 |
+
`num frames` x `channel` x `size` x `size`.
|
60 |
+
cropped_boxes (ndarray or None): the cropped boxes with dimension of
|
61 |
+
`num boxes` x 4.
|
62 |
+
"""
|
63 |
+
assert spatial_idx in [0, 1, 2]
|
64 |
+
ndim = len(images.shape)
|
65 |
+
if ndim == 3:
|
66 |
+
images = images.unsqueeze(0)
|
67 |
+
height = images.shape[2]
|
68 |
+
width = images.shape[3]
|
69 |
+
|
70 |
+
if scale_size is not None:
|
71 |
+
if width <= height:
|
72 |
+
width, height = scale_size, int(height / width * scale_size)
|
73 |
+
else:
|
74 |
+
width, height = int(width / height * scale_size), scale_size
|
75 |
+
images = torch.nn.functional.interpolate(
|
76 |
+
images,
|
77 |
+
size=(height, width),
|
78 |
+
mode="bilinear",
|
79 |
+
align_corners=False,
|
80 |
+
)
|
81 |
+
|
82 |
+
y_offset = int(math.ceil((height - size) / 2))
|
83 |
+
x_offset = int(math.ceil((width - size) / 2))
|
84 |
+
|
85 |
+
if height > width:
|
86 |
+
if spatial_idx == 0:
|
87 |
+
y_offset = 0
|
88 |
+
elif spatial_idx == 2:
|
89 |
+
y_offset = height - size
|
90 |
+
else:
|
91 |
+
if spatial_idx == 0:
|
92 |
+
x_offset = 0
|
93 |
+
elif spatial_idx == 2:
|
94 |
+
x_offset = width - size
|
95 |
+
cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
|
96 |
+
cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
|
97 |
+
if ndim == 3:
|
98 |
+
cropped = cropped.squeeze(0)
|
99 |
+
return cropped, cropped_boxes
|
100 |
+
|
101 |
+
|
102 |
+
class SpatialCrop(nn.Module):
|
103 |
+
"""
|
104 |
+
Convert the video into 3 smaller clips spatially. Must be used after the
|
105 |
+
temporal crops to get spatial crops, and should be used with
|
106 |
+
-2 in the spatial crop at the slowfast augmentation stage (so full
|
107 |
+
frames are passed in here). Will return a larger list with the
|
108 |
+
3x spatial crops as well.
|
109 |
+
"""
|
110 |
+
|
111 |
+
def __init__(self, crop_size: int = 224, num_crops: int = 3):
|
112 |
+
super().__init__()
|
113 |
+
self.crop_size = crop_size
|
114 |
+
if num_crops == 3:
|
115 |
+
self.crops_to_ext = [0, 1, 2]
|
116 |
+
self.flipped_crops_to_ext = []
|
117 |
+
elif num_crops == 1:
|
118 |
+
self.crops_to_ext = [1]
|
119 |
+
self.flipped_crops_to_ext = []
|
120 |
+
else:
|
121 |
+
raise NotImplementedError("Nothing else supported yet")
|
122 |
+
|
123 |
+
def forward(self, videos):
|
124 |
+
"""
|
125 |
+
Args:
|
126 |
+
videos: A list of C, T, H, W videos.
|
127 |
+
Returns:
|
128 |
+
videos: A list with 3x the number of elements. Each video converted
|
129 |
+
to C, T, H', W' by spatial cropping.
|
130 |
+
"""
|
131 |
+
assert isinstance(videos, list), "Must be a list of videos after temporal crops"
|
132 |
+
assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)"
|
133 |
+
res = []
|
134 |
+
for video in videos:
|
135 |
+
for spatial_idx in self.crops_to_ext:
|
136 |
+
res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
|
137 |
+
if not self.flipped_crops_to_ext:
|
138 |
+
continue
|
139 |
+
flipped_video = transforms.functional.hflip(video)
|
140 |
+
for spatial_idx in self.flipped_crops_to_ext:
|
141 |
+
res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
|
142 |
+
return res
|
143 |
+
|
144 |
+
|
145 |
+
def load_and_transform_video_data(
|
146 |
+
video_file,
|
147 |
+
video_path,
|
148 |
+
clip_duration=2,
|
149 |
+
clips_per_video=5,
|
150 |
+
sample_rate=16000,
|
151 |
+
with_audio=False
|
152 |
+
):
|
153 |
+
video_transform = transforms.Compose(
|
154 |
+
[
|
155 |
+
pv_transforms.ShortSideScale(224),
|
156 |
+
NormalizeVideo(
|
157 |
+
mean=(0.48145466, 0.4578275, 0.40821073),
|
158 |
+
std=(0.26862954, 0.26130258, 0.27577711),
|
159 |
+
),
|
160 |
+
]
|
161 |
+
)
|
162 |
+
|
163 |
+
clip_sampler = ConstantClipsPerVideoSampler(
|
164 |
+
clip_duration=clip_duration, clips_per_video=clips_per_video
|
165 |
+
)
|
166 |
+
frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration)
|
167 |
+
|
168 |
+
if isinstance(video_file, str):
|
169 |
+
video = EncodedVideo.from_path(
|
170 |
+
video_file,
|
171 |
+
decoder="decord",
|
172 |
+
decode_audio=with_audio,
|
173 |
+
# **{"sample_rate": sample_rate},
|
174 |
+
)
|
175 |
+
else:
|
176 |
+
video = EncodedVideoDecord(video_file, video_name=video_path, decode_video=True, decode_audio=with_audio, sample_rate=sample_rate)
|
177 |
+
|
178 |
+
all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration)
|
179 |
+
|
180 |
+
all_video = []
|
181 |
+
for clip_timepoints in all_clips_timepoints:
|
182 |
+
# Read the clip, get frames
|
183 |
+
clip = video.get_clip(clip_timepoints[0], clip_timepoints[1])
|
184 |
+
if clip is None:
|
185 |
+
raise ValueError("No clip found")
|
186 |
+
video_clip = frame_sampler(clip["video"])
|
187 |
+
video_clip = video_clip / 255.0  # scale uint8 frame values to floats in [0, 1]
|
188 |
+
|
189 |
+
all_video.append(video_clip)
|
190 |
+
|
191 |
+
all_video = [video_transform(clip) for clip in all_video]
|
192 |
+
all_video = SpatialCrop(224, num_crops=3)(all_video)
|
193 |
+
|
194 |
+
all_video = torch.stack(all_video, dim=0)
|
195 |
+
|
196 |
+
if not with_audio:
|
197 |
+
return all_video
|
198 |
+
else:
|
199 |
+
return all_video, clip['audio']
|
200 |
+
|
201 |
+
if __name__ == '__main__':
|
202 |
+
video_path = "datasets/InstructionTuning/video/music_aqa/MUSIC-AVQA-videos-Real/00000002.mp4"
|
203 |
+
video, audio = load_and_transform_video_data(video_path, video_path, clip_duration=1, clips_per_video=5, with_audio=True)
|
204 |
+
import pdb;pdb.set_trace()
|
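The video path samples a fixed number of short clips per video via pytorchvideo's `ConstantClipsPerVideoSampler`. As a rough, dependency-free approximation of the resulting `(start, end)` timepoints, evenly spaced clip starts can be computed directly; this is a sketch of the idea, not the library's exact scheduling:

```python
def constant_clip_timepoints(duration: float, clip_duration: float = 1.0,
                             clips_per_video: int = 5):
    """Evenly spaced (start, end) pairs covering the video, a plain-Python
    approximation of ConstantClipsPerVideoSampler + get_clip_timepoints."""
    if clips_per_video == 1:
        starts = [0.0]
    else:
        span = max(duration - clip_duration, 0.0)
        step = span / (clips_per_video - 1)
        starts = [i * step for i in range(clips_per_video)]
    return [(s, s + clip_duration) for s in starts]

print(constant_clip_timepoints(10.0))
# [(0.0, 1.0), (2.25, 3.25), (4.5, 5.5), (6.75, 7.75), (9.0, 10.0)]
```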
demos/multi_turn_mm.py
ADDED
@@ -0,0 +1,300 @@
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
sys.path.append(os.path.abspath(__file__).rsplit('/', 2)[0])
|
4 |
+
|
5 |
+
import argparse
|
6 |
+
import multiprocessing as mp
|
7 |
+
import numpy as np
|
8 |
+
from typing import List, Optional
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.distributed as dist
|
12 |
+
|
13 |
+
from fairscale.nn.model_parallel import initialize as fs_init
|
14 |
+
|
15 |
+
import gradio as gr
|
16 |
+
from util.misc import setup_for_distributed
|
17 |
+
from util.misc import default_tensor_type
|
18 |
+
from model.meta import MetaModel
|
19 |
+
from data.conversation_lib import conv_templates, SeparatorStyle
|
20 |
+
from PIL import Image
|
21 |
+
import torchvision.transforms as transforms
|
22 |
+
from data.fintune_dataset import make_audio_features
|
23 |
+
from data import video_utils
|
24 |
+
|
25 |
+
|
26 |
+
T_random_resized_crop = transforms.Compose([
|
27 |
+
transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333), interpolation=3,
|
28 |
+
antialias=None), # 3 is bicubic
|
29 |
+
transforms.ToTensor(),
|
30 |
+
transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
|
31 |
+
|
32 |
+
|
33 |
+
def load_audio(audio_path):
|
34 |
+
fbank = make_audio_features(audio_path, mel_bins=128)
|
35 |
+
fbank = fbank.transpose(0, 1)[None] #[1, 128, 1024]
|
36 |
+
return fbank
|
37 |
+
|
38 |
+
def load_video(video_path):
|
39 |
+
video_feats = video_utils.load_and_transform_video_data(video_path, video_path, clip_duration=1, clips_per_video=5)
|
40 |
+
return video_feats[:, :, 0]
|
41 |
+
|
42 |
+
|
43 |
+
def model_worker(
|
44 |
+
rank: int, args: argparse.Namespace, barrier: mp.Barrier,
|
45 |
+
request_queue: mp.Queue, response_queue: Optional[mp.Queue] = None,
|
46 |
+
) -> None:
|
47 |
+
"""
|
48 |
+
The worker function that runs inference on the GPU.
|
49 |
+
Exactly n_gpu workers are started, each operating on a separate GPU.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
rank (int): Distributed rank of the worker.
|
53 |
+
args (argparse.Namespace): All command line arguments.
|
54 |
+
barrier (multiprocessing.Barrier): A barrier used to delay the start
|
55 |
+
of Web UI to be after the start of the model.
|
56 |
+
"""
|
57 |
+
|
58 |
+
world_size = len(args.gpu_ids)
|
59 |
+
gpu_id = args.gpu_ids[rank]
|
60 |
+
dist.init_process_group(
|
61 |
+
backend="nccl", rank=rank, world_size=world_size,
|
62 |
+
init_method=f"tcp://{args.master_addr}:{args.master_port}",
|
63 |
+
)
|
64 |
+
print(f"| distributed init on worker {rank}/{world_size}. "
|
65 |
+
f"using gpu: {gpu_id}")
|
66 |
+
fs_init.initialize_model_parallel(world_size)
|
67 |
+
torch.cuda.set_device(gpu_id)
|
68 |
+
|
69 |
+
torch.manual_seed(1)
|
70 |
+
np.random.seed(1)
|
71 |
+
|
72 |
+
# set the print behavior.
|
73 |
+
setup_for_distributed(rank == 0)
|
74 |
+
|
75 |
+
target_dtype = {
|
76 |
+
"bf16": torch.bfloat16,
|
77 |
+
"fp16": torch.float16
|
78 |
+
}[args.dtype]
|
79 |
+
with default_tensor_type(dtype=target_dtype, device="cuda"):
|
80 |
+
model = MetaModel(args.llama_type, args.llama_config, tokenizer_path=args.tokenizer_path)
|
81 |
+
print("Loading pretrained weights ...")
|
82 |
+
checkpoint = torch.load(args.pretrained_path, map_location='cpu')
|
83 |
+
msg = model.load_state_dict(checkpoint, strict=False)
|
84 |
+
print("load result:\n", msg)
|
85 |
+
model.cuda()
|
86 |
+
model.eval()
|
87 |
+
print(f"Model = {str(model)}")
|
88 |
+
|
89 |
+
barrier.wait()
|
90 |
+
|
91 |
+
while True:
|
92 |
+
img_path, audio_path, video_path, chatbot, max_gen_len, temperature, top_p, modality = request_queue.get()
|
93 |
+
if 'image' in modality and img_path is not None:
|
94 |
+
image = Image.open(img_path).convert('RGB')
|
95 |
+
inputs = T_random_resized_crop(image)
|
96 |
+
elif 'video' in modality and video_path is not None:
|
97 |
+
inputs = load_video(video_path)
|
98 |
+
elif 'audio' in modality and audio_path is not None:
|
99 |
+
inputs = load_audio(audio_path)
|
100 |
+
else:
|
101 |
+
inputs = None
|
102 |
+
|
103 |
+
if inputs is not None:
|
104 |
+
inputs = inputs[None].cuda().to(target_dtype)
|
105 |
+
|
106 |
+
conv = conv_templates["v1"].copy()
|
107 |
+
for user, bot in chatbot:
|
108 |
+
conv.append_message(conv.roles[0], user)
|
109 |
+
conv.append_message(conv.roles[1], bot)
|
110 |
+
|
111 |
+
with torch.cuda.amp.autocast(dtype=target_dtype):
|
112 |
+
print(conv.get_prompt())
|
113 |
+
for stream_response in model.stream_generate(
|
114 |
+
conv.get_prompt(), inputs,
|
115 |
+
max_gen_len=max_gen_len, temperature=temperature, top_p=top_p,
|
116 |
+
modal = modality
|
117 |
+
):
|
118 |
+
conv_sep = (
|
119 |
+
conv.sep
|
120 |
+
if conv.sep_style == SeparatorStyle.SINGLE
|
121 |
+
else conv.sep2
|
122 |
+
)
|
123 |
+
end_pos = stream_response["text"].find(conv_sep)
|
124 |
+
if end_pos != -1:
|
125 |
+
stream_response["text"] = (
|
126 |
+
stream_response['text'][:end_pos].rstrip() + "\n"
|
127 |
+
)
|
128 |
+
stream_response["end_of_content"] = True
|
129 |
+
|
130 |
+
# keep a few characters if not end_of_content to avoid sending
|
131 |
+
# part of conv_sep before all of it is generated.
|
132 |
+
if not stream_response["end_of_content"]:
|
133 |
+
if len(stream_response["text"]) < len(conv_sep):
|
134 |
+
continue
|
135 |
+
stream_response["text"] = (
|
136 |
+
stream_response["text"][:-len(conv_sep)]
|
137 |
+
)
|
138 |
+
|
139 |
+
if response_queue is not None:
|
140 |
+
response_queue.put(stream_response)
|
141 |
+
|
142 |
+
if stream_response["end_of_content"]:
|
143 |
+
break
|
144 |
+
|
145 |
+
|
146 |
+
def gradio_worker(
|
147 |
+
request_queues: List[mp.Queue], response_queue: mp.Queue,
|
148 |
+
args: argparse.Namespace, barrier: mp.Barrier,
|
149 |
+
) -> None:
|
150 |
+
"""
|
151 |
+
The gradio worker is responsible for displaying the WebUI and relaying the
|
152 |
+
requests to model workers. It should be launched only once.
|
153 |
+
|
154 |
+
Args:
|
155 |
+
request_queues (List[mp.Queue]): A list of request queues (one for
|
156 |
+
each model worker).
|
157 |
+
args (argparse.Namespace): All command line arguments.
|
158 |
+
barrier (multiprocessing.Barrier): A barrier used to delay the start
|
159 |
+
of Web UI to be after the start of the model.
|
160 |
+
"""
|
161 |
+
|
162 |
+
def show_user_input(msg, chatbot):
|
163 |
+
return "", chatbot + [[msg, None]]
|
164 |
+
|
165 |
+
def stream_model_output(img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality):
|
166 |
+
for queue in request_queues:
|
167 |
+
queue.put((img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality))
|
168 |
+
while True:
|
169 |
+
content_piece = response_queue.get()
|
170 |
+
chatbot[-1][1] = content_piece["text"]
|
171 |
+
yield chatbot
|
172 |
+
if content_piece["end_of_content"]:
|
173 |
+
break
|
174 |
+
|
175 |
+
def undo(chatbot):
|
176 |
+
if len(chatbot) > 0:
|
177 |
+
chatbot = chatbot[:-1]
|
178 |
+
return chatbot
|
179 |
+
|
180 |
+
def clear():
|
181 |
+
chatbot = []
|
182 |
+
msg = ""
|
183 |
+
return chatbot, msg
|
184 |
+
|
185 |
+
CSS ="""
|
186 |
+
.contain { display: flex; flex-direction: column; }
|
187 |
+
#component-0 { height: 100%; }
|
188 |
+
#chatbot { flex-grow: 1; overflow: auto;}
|
189 |
+
"""
|
190 |
+
with gr.Blocks(css=CSS) as demo:
|
191 |
+
gr.Markdown("## OneLLM: One Framework to Align All Modalities with Language")
|
192 |
+
with gr.Row(equal_height=True):
|
193 |
+
with gr.Column(scale=1):
|
194 |
+
img_path = gr.Image(label='Image Input', type='filepath')
|
195 |
+
video_path = gr.Video(label='Video Input')
|
196 |
+
audio_path = gr.Audio(label='Audio Input', type='filepath', sources=['upload'])
|
197 |
+
modality = gr.Radio(choices=['image', 'audio', 'video'], value='image', interactive=True, label='Input Modalities')
|
198 |
+
|
199 |
+
with gr.Column(scale=2):
|
200 |
+
chatbot = gr.Chatbot(elem_id="chatbot")
|
201 |
+
msg = gr.Textbox()
|
202 |
+
|
203 |
+
with gr.Row():
|
204 |
+
submit_button = gr.Button("Submit", variant="primary")
|
205 |
+
undo_button = gr.Button("Undo")
|
206 |
+
clear_button = gr.ClearButton([chatbot, msg, img_path, audio_path, video_path, modality])
|
207 |
+
with gr.Row():
|
208 |
+
max_gen_len = gr.Slider(
|
209 |
+
minimum=1, maximum=args.model_max_seq_len // 2,
|
210 |
+
value=args.model_max_seq_len // 2, interactive=True,
|
211 |
+
label="Single-turn max response length",
|
212 |
+
)
|
213 |
+
gen_t = gr.Slider(
|
214 |
+
minimum=0, maximum=1, value=0.1, interactive=True,
|
215 |
+
label="Temperature",
|
216 |
+
)
|
217 |
+
top_p = gr.Slider(
|
218 |
+
minimum=0, maximum=1, value=0.75, interactive=True,
|
219 |
+
label="Top-p",
|
220 |
+
)
|
221 |
+
msg.submit(
|
222 |
+
show_user_input, [msg, chatbot], [msg, chatbot],
|
223 |
+
).then(
|
224 |
+
stream_model_output, [img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot,
|
225 |
+
)
|
226 |
+
submit_button.click(
|
227 |
+
show_user_input, [msg, chatbot], [msg, chatbot],
|
228 |
+
).then(
|
229 |
+
stream_model_output, [img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot,
|
230 |
+
)
|
231 |
+
undo_button.click(undo, chatbot, chatbot)
|
232 |
+
# img_path.change(clear, [], [chatbot, msg])
|
233 |
+
barrier.wait()
|
234 |
+
demo.queue(api_open=True).launch(share=True, max_threads=1)
|
235 |
+
|
236 |
+
|
237 |
+
if __name__ == "__main__":
|
238 |
+
parser = argparse.ArgumentParser("Chat Demo")
|
239 |
+
group = parser.add_mutually_exclusive_group()
|
240 |
+
group.add_argument(
|
241 |
+
"--gpu_ids", type=int, nargs="+",
|
242 |
+
help="A list of space-separated gpu ids to run the model on. "
|
243 |
+
"The model will span across GPUs in tensor-parallel mode."
|
244 |
+
)
|
245 |
+
parser.add_argument(
|
246 |
+
"--tokenizer_path", type=str,
|
247 |
+
help="Path to the tokenizer.model file provided along with the LLaMA "
|
248 |
+
"model."
|
249 |
+
)
|
250 |
+
parser.add_argument(
|
251 |
+
"--llama_type", default="onellm", type=str, metavar="MODEL",
|
252 |
+
help="LLaMA model type."
|
253 |
+
)
|
254 |
+
parser.add_argument(
|
255 |
+
"--llama_config", type=str, required=True,
|
256 |
+
help="Path to the llama model config json."
|
257 |
+
)
|
258 |
+
parser.add_argument(
|
259 |
+
"--model_max_seq_len", type=int, default=2048,
|
260 |
+
help="Max sequence length accepted by the pretrained model."
|
261 |
+
)
|
262 |
+
parser.add_argument(
|
263 |
+
"--pretrained_path", type=str, required=True,
|
264 |
+
help="Path to the llama model checkpoints. A list of checkpoints is "
|
265 |
+
"supported and will be merged from left to right.")
|
266 |
+
parser.add_argument(
|
267 |
+
"--master_port", type=int, default=23862,
|
268 |
+
help="A port used by the PyTorch distributed module to initialize."
|
269 |
+
)
|
270 |
+
parser.add_argument(
|
271 |
+
"--master_addr", type=str, default="127.0.0.1",
|
272 |
+
help="An address used by the PyTorch distributed module to initialize."
|
273 |
+
)
|
274 |
+
parser.add_argument(
|
275 |
+
"--dtype", type=str, choices=["fp16", "bf16"], default="fp16",
|
276 |
+
help="The dtype used for model weights and inference."
|
277 |
+
)
|
278 |
+
args = parser.parse_args()
|
279 |
+
|
280 |
+
# using the default "fork" method messes up some imported libs (e.g.,
|
281 |
+
# pandas)
|
282 |
+
mp.set_start_method("spawn")
|
283 |
+
|
284 |
+
# setup the queues and start the model workers
|
285 |
+
request_queues = []
|
286 |
+
response_queue = mp.Queue()
|
287 |
+
worker_processes = []
|
288 |
+
barrier = mp.Barrier(len(args.gpu_ids) + 1)
|
289 |
+
for rank, gpu_id in enumerate(args.gpu_ids):
|
290 |
+
request_queue = mp.Queue()
|
291 |
+
rank_response_queue = response_queue if rank == 0 else None
|
292 |
+
process = mp.Process(
|
293 |
+
target=model_worker,
|
294 |
+
args=(rank, args, barrier, request_queue, rank_response_queue),
|
295 |
+
)
|
296 |
+
process.start()
|
297 |
+
worker_processes.append(process)
|
298 |
+
request_queues.append(request_queue)
|
299 |
+
|
300 |
+
gradio_worker(request_queues, response_queue, args, barrier)
|
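Architecturally, the demo runs one Gradio front-end process and one model worker per GPU, wired together with multiprocessing queues: every worker receives each request (for tensor parallelism), but only rank 0 pushes text back. A stripped-down sketch of that pattern with echo workers standing in for the model (the barrier and `spawn` start method mirror the script above):

```python
import multiprocessing as mp

def echo_worker(rank, barrier, request_queue, response_queue):
    barrier.wait()                       # the model would be loaded before this point
    while True:
        prompt = request_queue.get()
        if prompt is None:               # sentinel to shut the worker down
            break
        if response_queue is not None:   # only rank 0 answers, as in the demo
            response_queue.put(f"[rank {rank}] echo: {prompt}")

if __name__ == "__main__":
    mp.set_start_method("spawn")
    n_workers = 2
    request_queues = [mp.Queue() for _ in range(n_workers)]
    response_queue = mp.Queue()
    barrier = mp.Barrier(n_workers + 1)  # +1 for the front-end process
    procs = []
    for rank, q in enumerate(request_queues):
        p = mp.Process(target=echo_worker,
                       args=(rank, barrier, q,
                             response_queue if rank == 0 else None))
        p.start()
        procs.append(p)

    barrier.wait()                       # front-end starts only after all workers
    for q in request_queues:             # broadcast one request to all ranks
        q.put("describe the image")
    print(response_queue.get())
    for q in request_queues:
        q.put(None)
    for p in procs:
        p.join()
```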
lib/__pycache__/point_utils.cpython-310.pyc
ADDED
Binary file (6.74 kB).
lib/point_utils.py
ADDED
@@ -0,0 +1,191 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.autograd import Function
|
4 |
+
import pointnet2_cuda
|
5 |
+
|
6 |
+
class KNN(nn.Module):
|
7 |
+
def __init__(self, neighbors, transpose_mode=True):
|
8 |
+
super(KNN, self).__init__()
|
9 |
+
self.neighbors = neighbors
|
10 |
+
|
11 |
+
@torch.no_grad()
|
12 |
+
def forward(self, support, query):
|
13 |
+
"""
|
14 |
+
Args:
|
15 |
+
support ([tensor]): [B, N, C]
|
16 |
+
query ([tensor]): [B, M, C]
|
17 |
+
Returns:
|
18 |
+
[int]: neighbor idx. [B, M, K]
|
19 |
+
"""
|
20 |
+
dist = torch.cdist(support, query)
|
21 |
+
k_dist = dist.topk(k=self.neighbors, dim=1, largest=False)
|
22 |
+
return k_dist.values, k_dist.indices.transpose(1, 2).contiguous().int()
|
23 |
+
|
24 |
+
|
25 |
+
class GroupingOperation(Function):
|
26 |
+
|
27 |
+
@staticmethod
|
28 |
+
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
29 |
+
def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
|
30 |
+
"""
|
31 |
+
:param ctx:
|
32 |
+
:param features: (B, C, N) tensor of features to group
|
33 |
+
:param idx: (B, npoint, nsample) tensor containing the indices of features to group with
|
34 |
+
:return:
|
35 |
+
output: (B, C, npoint, nsample) tensor
|
36 |
+
"""
|
37 |
+
assert features.is_contiguous()
|
38 |
+
assert idx.is_contiguous()
|
39 |
+
|
40 |
+
B, nfeatures, nsample = idx.size()
|
41 |
+
_, C, N = features.size()
|
42 |
+
output = torch.cuda.FloatTensor(B, C, nfeatures, nsample, device=features.device)
|
43 |
+
|
44 |
+
pointnet2_cuda.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output)
|
45 |
+
|
46 |
+
ctx.for_backwards = (idx, N)
|
47 |
+
return output
|
48 |
+
|
49 |
+
@staticmethod
|
50 |
+
def backward(ctx, grad_out: torch.Tensor):
|
51 |
+
"""
|
52 |
+
:param ctx:
|
53 |
+
:param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward
|
54 |
+
:return:
|
55 |
+
grad_features: (B, C, N) gradient of the features
|
56 |
+
"""
|
57 |
+
idx, N = ctx.for_backwards
|
58 |
+
|
59 |
+
B, C, npoint, nsample = grad_out.size()
|
60 |
+
grad_features = torch.zeros([B, C, N], dtype=torch.float, device=grad_out.device, requires_grad=True)
|
61 |
+
grad_out_data = grad_out.data.contiguous()
|
62 |
+
pointnet2_cuda.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data)
|
63 |
+
return grad_features, None
|
64 |
+
|
65 |
+
grouping_operation = GroupingOperation.apply
|
66 |
+
|
67 |
+
|
68 |
+
class KNNGroup(nn.Module):
|
69 |
+
def __init__(self, nsample: int,
|
70 |
+
relative_xyz=True,
|
71 |
+
normalize_dp=False,
|
72 |
+
return_only_idx=False,
|
73 |
+
**kwargs
|
74 |
+
):
|
75 |
+
"""[summary]
|
76 |
+
|
77 |
+
Args:
|
78 |
+
nsample (int): maximum number of features to gather in the ball
|
79 |
+
use_xyz (bool, optional): concatenate xyz. Defaults to True.
|
80 |
+
ret_grouped_xyz (bool, optional): [description]. Defaults to False.
|
81 |
+
normalize_dp (bool, optional): normalize relative coordinates by the largest neighbour distance. Defaults to False.
|
82 |
+
"""
|
83 |
+
super().__init__()
|
84 |
+
self.nsample = nsample
|
85 |
+
self.knn = KNN(nsample, transpose_mode=True)
|
86 |
+
self.relative_xyz = relative_xyz
|
87 |
+
self.normalize_dp = normalize_dp
|
88 |
+
self.return_only_idx = return_only_idx
|
89 |
+
|
90 |
+
def forward(self, query_xyz: torch.Tensor, support_xyz: torch.Tensor, features: torch.Tensor = None):
|
91 |
+
"""
|
92 |
+
:param query_xyz: (B, N, 3) xyz coordinates of the features
|
93 |
+
:param support_xyz: (B, npoint, 3) centroids
|
94 |
+
:param features: (B, C, N) descriptors of the features
|
95 |
+
:return:
|
96 |
+
new_features: (B, 3 + C, npoint, nsample)
|
97 |
+
"""
|
98 |
+
_, idx = self.knn(support_xyz, query_xyz)
|
99 |
+
if self.return_only_idx:
|
100 |
+
return idx
|
101 |
+
idx = idx.int()
|
102 |
+
xyz_trans = support_xyz.transpose(1, 2).contiguous()
|
103 |
+
grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample)
|
104 |
+
if self.relative_xyz:
|
105 |
+
grouped_xyz -= query_xyz.transpose(1, 2).unsqueeze(-1) # relative position
|
106 |
+
if self.normalize_dp:
|
107 |
+
grouped_xyz /= torch.amax(torch.sqrt(torch.sum(grouped_xyz**2, dim=1)), dim=(1, 2)).view(-1, 1, 1, 1)
|
108 |
+
if features is not None:
|
109 |
+
grouped_features = grouping_operation(features, idx)
|
110 |
+
return grouped_xyz, grouped_features
|
111 |
+
else:
|
112 |
+
return grouped_xyz, None
|
113 |
+
|
114 |
+
|
115 |
+
class FurthestPointSampling(Function):
|
116 |
+
@staticmethod
|
117 |
+
def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
|
118 |
+
"""
|
119 |
+
Uses iterative furthest point sampling to select a set of npoint features that have the largest
|
120 |
+
minimum distance
|
121 |
+
:param ctx:
|
122 |
+
:param xyz: (B, N, 3) where N > npoint
|
123 |
+
:param npoint: int, number of features in the sampled set
|
124 |
+
:return:
|
125 |
+
output: (B, npoint) tensor containing the set (idx)
|
126 |
+
"""
|
127 |
+
assert xyz.is_contiguous()
|
128 |
+
|
129 |
+
B, N, _ = xyz.size()
|
130 |
+
# output = torch.cuda.IntTensor(B, npoint, device=xyz.device)
|
131 |
+
# temp = torch.cuda.FloatTensor(B, N, device=xyz.device).fill_(1e10)
|
132 |
+
output = torch.cuda.IntTensor(B, npoint)
|
133 |
+
temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
|
134 |
+
|
135 |
+
pointnet2_cuda.furthest_point_sampling_wrapper(
|
136 |
+
B, N, npoint, xyz, temp, output)
|
137 |
+
return output
|
138 |
+
|
139 |
+
@staticmethod
|
140 |
+
def backward(xyz, a=None):
|
141 |
+
return None, None
|
142 |
+
|
143 |
+
furthest_point_sample = FurthestPointSampling.apply
|
144 |
+
|
145 |
+
|
146 |
+
class PointPatchEmbed(nn.Module):
|
147 |
+
|
148 |
+
def __init__(self,
|
149 |
+
sample_ratio=0.0625,
|
150 |
+
sample_number=1024,
|
151 |
+
group_size=32,
|
152 |
+
in_channels=6,
|
153 |
+
channels=1024,
|
154 |
+
kernel_size=1,
|
155 |
+
stride=1,
|
156 |
+
normalize_dp=False,
|
157 |
+
relative_xyz=True,
|
158 |
+
):
|
159 |
+
super().__init__()
|
160 |
+
self.sample_ratio = sample_ratio
|
161 |
+
self.sample_number = sample_number
|
162 |
+
self.group_size = group_size
|
163 |
+
|
164 |
+
self.sample_fn = furthest_point_sample
|
165 |
+
self.grouper = KNNGroup(self.group_size, relative_xyz=relative_xyz, normalize_dp=normalize_dp)
|
166 |
+
|
167 |
+
self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=kernel_size, stride=stride)
|
168 |
+
|
169 |
+
|
170 |
+
def forward(self, x):
|
171 |
+
# coordinates
|
172 |
+
p = x[:, :, 3:].contiguous()
|
173 |
+
|
174 |
+
B, N, _ = p.shape[:3]
|
175 |
+
# idx = self.sample_fn(p, int(N * self.sample_ratio)).long()
|
176 |
+
idx = self.sample_fn(p, self.sample_number).long()
|
177 |
+
center_p = torch.gather(p, 1, idx.unsqueeze(-1).expand(-1, -1, 3))
|
178 |
+
# query neighbors.
|
179 |
+
_, fj = self.grouper(center_p, p, x.permute(0, 2, 1).contiguous()) # [B, N, 6] -> [B, 6, N] -> [B, 6, 1024, 32]
|
180 |
+
|
181 |
+
# [B, 6, 1024] -> [B, channels, 1024, 1]
|
182 |
+
fj = self.conv1(fj).max(dim=-1, keepdim=True)[0]
|
183 |
+
|
184 |
+
return fj
|
185 |
+
|
186 |
+
|
187 |
+
if __name__ == '__main__':
|
188 |
+
model = PointPatchEmbed(channels=256).cuda()
|
189 |
+
input = torch.rand(4, 16384, 6).cuda()
|
190 |
+
ou = model(input)
|
191 |
+
import pdb;pdb.set_trace()
|
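`KNNGroup` above leans on the compiled `pointnet2_cuda` extension for the gather step. For sanity checks on CPU, the same relative-xyz neighbour grouping can be written with plain PyTorch; this is a sketch following the shapes in the docstrings above, not a drop-in replacement for the extension:

```python
import torch

def knn_group(query_xyz, support_xyz, k=32, relative_xyz=True):
    """query_xyz: (B, M, 3) centroids, support_xyz: (B, N, 3) full cloud.
    Returns grouped neighbour coordinates of shape (B, 3, M, k)."""
    dist = torch.cdist(query_xyz, support_xyz)             # (B, M, N)
    idx = dist.topk(k, dim=-1, largest=False).indices      # (B, M, k)
    # gather the k neighbours of every centroid
    idx_expand = idx.unsqueeze(-1).expand(-1, -1, -1, 3)   # (B, M, k, 3)
    grouped = torch.gather(
        support_xyz.unsqueeze(1).expand(-1, idx.shape[1], -1, -1),  # (B, M, N, 3)
        2, idx_expand)                                      # (B, M, k, 3)
    if relative_xyz:
        grouped = grouped - query_xyz.unsqueeze(2)          # recentre on the query
    return grouped.permute(0, 3, 1, 2).contiguous()         # (B, 3, M, k)

pts = torch.rand(2, 1024, 3)
centers = pts[:, :128, :]                                   # any 128 query points
print(knn_group(centers, pts, k=16).shape)                  # torch.Size([2, 3, 128, 16])
```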
lib/pointnet2/pointnet2_modules.py
ADDED
@@ -0,0 +1,160 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from . import pointnet2_utils
|
6 |
+
from . import pytorch_utils as pt_utils
|
7 |
+
from typing import List
|
8 |
+
|
9 |
+
|
10 |
+
class _PointnetSAModuleBase(nn.Module):
|
11 |
+
|
12 |
+
def __init__(self):
|
13 |
+
super().__init__()
|
14 |
+
self.npoint = None
|
15 |
+
self.groupers = None
|
16 |
+
self.mlps = None
|
17 |
+
self.pool_method = 'max_pool'
|
18 |
+
|
19 |
+
def forward(self, xyz: torch.Tensor, features: torch.Tensor = None, new_xyz=None) -> (torch.Tensor, torch.Tensor):
|
20 |
+
"""
|
21 |
+
:param xyz: (B, N, 3) tensor of the xyz coordinates of the features
|
22 |
+
:param features: (B, N, C) tensor of the descriptors of the features
|
23 |
+
:param new_xyz:
|
24 |
+
:return:
|
25 |
+
new_xyz: (B, npoint, 3) tensor of the new features' xyz
|
26 |
+
new_features: (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors
|
27 |
+
"""
|
28 |
+
new_features_list = []
|
29 |
+
|
30 |
+
xyz_flipped = xyz.transpose(1, 2).contiguous()
|
31 |
+
if new_xyz is None:
|
32 |
+
new_xyz = pointnet2_utils.gather_operation(
|
33 |
+
xyz_flipped,
|
34 |
+
pointnet2_utils.furthest_point_sample(xyz, self.npoint)
|
35 |
+
).transpose(1, 2).contiguous() if self.npoint is not None else None
|
36 |
+
|
37 |
+
for i in range(len(self.groupers)):
|
38 |
+
new_features = self.groupers[i](xyz, new_xyz, features) # (B, C, npoint, nsample)
|
39 |
+
|
40 |
+
new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample)
|
41 |
+
if self.pool_method == 'max_pool':
|
42 |
+
new_features = F.max_pool2d(
|
43 |
+
new_features, kernel_size=[1, new_features.size(3)]
|
44 |
+
) # (B, mlp[-1], npoint, 1)
|
45 |
+
elif self.pool_method == 'avg_pool':
|
46 |
+
new_features = F.avg_pool2d(
|
47 |
+
new_features, kernel_size=[1, new_features.size(3)]
|
48 |
+
) # (B, mlp[-1], npoint, 1)
|
49 |
+
else:
|
50 |
+
raise NotImplementedError
|
51 |
+
|
52 |
+
new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint)
|
53 |
+
new_features_list.append(new_features)
|
54 |
+
|
55 |
+
return new_xyz, torch.cat(new_features_list, dim=1)
|
56 |
+
|
57 |
+
|
58 |
+
class PointnetSAModuleMSG(_PointnetSAModuleBase):
|
59 |
+
"""Pointnet set abstraction layer with multiscale grouping"""
|
60 |
+
|
61 |
+
def __init__(self, *, npoint: int, radii: List[float], nsamples: List[int], mlps: List[List[int]], bn: bool = True,
|
62 |
+
use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
|
63 |
+
"""
|
64 |
+
:param npoint: int
|
65 |
+
:param radii: list of float, list of radii to group with
|
66 |
+
:param nsamples: list of int, number of samples in each ball query
|
67 |
+
:param mlps: list of list of int, spec of the pointnet before the global pooling for each scale
|
68 |
+
:param bn: whether to use batchnorm
|
69 |
+
:param use_xyz:
|
70 |
+
:param pool_method: max_pool / avg_pool
|
71 |
+
:param instance_norm: whether to use instance_norm
|
72 |
+
"""
|
73 |
+
super().__init__()
|
74 |
+
|
75 |
+
assert len(radii) == len(nsamples) == len(mlps)
|
76 |
+
|
77 |
+
self.npoint = npoint
|
78 |
+
self.groupers = nn.ModuleList()
|
79 |
+
self.mlps = nn.ModuleList()
|
80 |
+
for i in range(len(radii)):
|
81 |
+
radius = radii[i]
|
82 |
+
nsample = nsamples[i]
|
83 |
+
self.groupers.append(
|
84 |
+
pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
|
85 |
+
if npoint is not None else pointnet2_utils.GroupAll(use_xyz)
|
86 |
+
)
|
87 |
+
mlp_spec = mlps[i]
|
88 |
+
if use_xyz:
|
89 |
+
mlp_spec[0] += 3
|
90 |
+
|
91 |
+
self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn, instance_norm=instance_norm))
|
92 |
+
self.pool_method = pool_method
|
93 |
+
|
94 |
+
|
95 |
+
class PointnetSAModule(PointnetSAModuleMSG):
|
96 |
+
"""Pointnet set abstraction layer"""
|
97 |
+
|
98 |
+
def __init__(self, *, mlp: List[int], npoint: int = None, radius: float = None, nsample: int = None,
|
99 |
+
bn: bool = True, use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
|
100 |
+
"""
|
101 |
+
:param mlp: list of int, spec of the pointnet before the global max_pool
|
102 |
+
:param npoint: int, number of features
|
103 |
+
:param radius: float, radius of ball
|
104 |
+
:param nsample: int, number of samples in the ball query
|
105 |
+
:param bn: whether to use batchnorm
|
106 |
+
:param use_xyz:
|
107 |
+
:param pool_method: max_pool / avg_pool
|
108 |
+
:param instance_norm: whether to use instance_norm
|
109 |
+
"""
|
110 |
+
super().__init__(
|
111 |
+
mlps=[mlp], npoint=npoint, radii=[radius], nsamples=[nsample], bn=bn, use_xyz=use_xyz,
|
112 |
+
pool_method=pool_method, instance_norm=instance_norm
|
113 |
+
)
|
114 |
+
|
115 |
+
|
116 |
+
class PointnetFPModule(nn.Module):
|
117 |
+
r"""Propigates the features of one set to another"""
|
118 |
+
|
119 |
+
def __init__(self, *, mlp: List[int], bn: bool = True):
|
120 |
+
"""
|
121 |
+
:param mlp: list of int
|
122 |
+
:param bn: whether to use batchnorm
|
123 |
+
"""
|
124 |
+
super().__init__()
|
125 |
+
self.mlp = pt_utils.SharedMLP(mlp, bn=bn)
|
126 |
+
|
127 |
+
def forward(
|
128 |
+
self, unknown: torch.Tensor, known: torch.Tensor, unknow_feats: torch.Tensor, known_feats: torch.Tensor
|
129 |
+
) -> torch.Tensor:
|
130 |
+
"""
|
131 |
+
:param unknown: (B, n, 3) tensor of the xyz positions of the unknown features
|
132 |
+
:param known: (B, m, 3) tensor of the xyz positions of the known features
|
133 |
+
:param unknow_feats: (B, C1, n) tensor of the features to be propagated to
|
134 |
+
:param known_feats: (B, C2, m) tensor of features to be propigated
|
135 |
+
:return:
|
136 |
+
new_features: (B, mlp[-1], n) tensor of the features of the unknown features
|
137 |
+
"""
|
138 |
+
if known is not None:
|
139 |
+
dist, idx = pointnet2_utils.three_nn(unknown, known)
|
140 |
+
dist_recip = 1.0 / (dist + 1e-8)
|
141 |
+
norm = torch.sum(dist_recip, dim=2, keepdim=True)
|
142 |
+
weight = dist_recip / norm
|
143 |
+
|
144 |
+
interpolated_feats = pointnet2_utils.three_interpolate(known_feats, idx, weight)
|
145 |
+
else:
|
146 |
+
interpolated_feats = known_feats.expand(*known_feats.size()[0:2], unknown.size(1))
|
147 |
+
|
148 |
+
if unknow_feats is not None:
|
149 |
+
new_features = torch.cat([interpolated_feats, unknow_feats], dim=1) # (B, C2 + C1, n)
|
150 |
+
else:
|
151 |
+
new_features = interpolated_feats
|
152 |
+
|
153 |
+
new_features = new_features.unsqueeze(-1)
|
154 |
+
new_features = self.mlp(new_features)
|
155 |
+
|
156 |
+
return new_features.squeeze(-1)
|
157 |
+
|
158 |
+
|
159 |
+
if __name__ == "__main__":
|
160 |
+
pass
|
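For orientation, a minimal usage sketch of the modules above (not part of the upload; it assumes the pointnet2_cuda extension is built, a CUDA device is available, and that lib/pointnet2 is on sys.path so the module imports resolve):

# Hypothetical usage sketch of the set-abstraction and feature-propagation modules.
import torch
from pointnet2_modules import PointnetSAModuleMSG, PointnetFPModule  # import path depends on sys.path setup

xyz = torch.rand(2, 1024, 3).cuda()         # (B, N, 3) point coordinates
features = torch.rand(2, 16, 1024).cuda()   # (B, C, N) per-point features

# Multi-scale set abstraction: two radii, each with its own small MLP.
# mlp_spec[0] must equal C; 3 is added internally when use_xyz=True.
sa = PointnetSAModuleMSG(npoint=256, radii=[0.1, 0.2], nsamples=[16, 32],
                         mlps=[[16, 32], [16, 32]], use_xyz=True).cuda()
new_xyz, new_features = sa(xyz, features)   # (2, 256, 3), (2, 64, 256)

# Propagate the coarse features back to the original resolution.
fp = PointnetFPModule(mlp=[64 + 16, 64]).cuda()
up_features = fp(xyz, new_xyz, features, new_features)  # (2, 64, 1024)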
lib/pointnet2/pointnet2_utils.py
ADDED
@@ -0,0 +1,290 @@
import torch
from torch.autograd import Variable
from torch.autograd import Function
import torch.nn as nn
from typing import Tuple

import pointnet2_cuda as pointnet2


class FurthestPointSampling(Function):
    @staticmethod
    def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
        """
        Uses iterative furthest point sampling to select a set of npoint features that have the largest
        minimum distance
        :param ctx:
        :param xyz: (B, N, 3) where N > npoint
        :param npoint: int, number of features in the sampled set
        :return:
            output: (B, npoint) tensor containing the set
        """
        assert xyz.is_contiguous()

        B, N, _ = xyz.size()
        output = torch.cuda.IntTensor(B, npoint)
        temp = torch.cuda.FloatTensor(B, N).fill_(1e10)

        pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
        return output

    @staticmethod
    def backward(xyz, a=None):
        return None, None


furthest_point_sample = FurthestPointSampling.apply


class GatherOperation(Function):

    @staticmethod
    def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        """
        :param ctx:
        :param features: (B, C, N)
        :param idx: (B, npoint) index tensor of the features to gather
        :return:
            output: (B, C, npoint)
        """
        assert features.is_contiguous()
        assert idx.is_contiguous()

        B, npoint = idx.size()
        _, C, N = features.size()
        output = torch.cuda.FloatTensor(B, C, npoint)

        pointnet2.gather_points_wrapper(B, C, N, npoint, features, idx, output)

        ctx.for_backwards = (idx, C, N)
        return output

    @staticmethod
    def backward(ctx, grad_out):
        idx, C, N = ctx.for_backwards
        B, npoint = idx.size()

        grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
        grad_out_data = grad_out.data.contiguous()
        pointnet2.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features.data)
        return grad_features, None


gather_operation = GatherOperation.apply


class ThreeNN(Function):

    @staticmethod
    def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Find the three nearest neighbors of unknown in known
        :param ctx:
        :param unknown: (B, N, 3)
        :param known: (B, M, 3)
        :return:
            dist: (B, N, 3) l2 distance to the three nearest neighbors
            idx: (B, N, 3) index of 3 nearest neighbors
        """
        assert unknown.is_contiguous()
        assert known.is_contiguous()

        B, N, _ = unknown.size()
        m = known.size(1)
        dist2 = torch.cuda.FloatTensor(B, N, 3)
        idx = torch.cuda.IntTensor(B, N, 3)

        pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx)
        return torch.sqrt(dist2), idx

    @staticmethod
    def backward(ctx, a=None, b=None):
        return None, None


three_nn = ThreeNN.apply


class ThreeInterpolate(Function):

    @staticmethod
    def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        """
        Performs weighted linear interpolation on 3 features
        :param ctx:
        :param features: (B, C, M) feature descriptors to be interpolated from
        :param idx: (B, n, 3) three nearest neighbors of the target features in features
        :param weight: (B, n, 3) weights
        :return:
            output: (B, C, N) tensor of the interpolated features
        """
        assert features.is_contiguous()
        assert idx.is_contiguous()
        assert weight.is_contiguous()

        B, c, m = features.size()
        n = idx.size(1)
        ctx.three_interpolate_for_backward = (idx, weight, m)
        output = torch.cuda.FloatTensor(B, c, n)

        pointnet2.three_interpolate_wrapper(B, c, m, n, features, idx, weight, output)
        return output

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        :param ctx:
        :param grad_out: (B, C, N) tensor with gradients of outputs
        :return:
            grad_features: (B, C, M) tensor with gradients of features
            None:
            None:
        """
        idx, weight, m = ctx.three_interpolate_for_backward
        B, c, n = grad_out.size()

        grad_features = Variable(torch.cuda.FloatTensor(B, c, m).zero_())
        grad_out_data = grad_out.data.contiguous()

        pointnet2.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data, idx, weight, grad_features.data)
        return grad_features, None, None


three_interpolate = ThreeInterpolate.apply


class GroupingOperation(Function):

    @staticmethod
    def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        """
        :param ctx:
        :param features: (B, C, N) tensor of features to group
        :param idx: (B, npoint, nsample) tensor containing the indices of features to group with
        :return:
            output: (B, C, npoint, nsample) tensor
        """
        assert features.is_contiguous()
        assert idx.is_contiguous()

        B, nfeatures, nsample = idx.size()
        _, C, N = features.size()
        output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)

        pointnet2.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output)

        ctx.for_backwards = (idx, N)
        return output

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param ctx:
        :param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward
        :return:
            grad_features: (B, C, N) gradient of the features
        """
        idx, N = ctx.for_backwards

        B, C, npoint, nsample = grad_out.size()
        grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())

        grad_out_data = grad_out.data.contiguous()
        pointnet2.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data)
        return grad_features, None


grouping_operation = GroupingOperation.apply


class BallQuery(Function):

    @staticmethod
    def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor:
        """
        :param ctx:
        :param radius: float, radius of the balls
        :param nsample: int, maximum number of features in the balls
        :param xyz: (B, N, 3) xyz coordinates of the features
        :param new_xyz: (B, npoint, 3) centers of the ball query
        :return:
            idx: (B, npoint, nsample) tensor with the indices of the features that form the query balls
        """
        assert new_xyz.is_contiguous()
        assert xyz.is_contiguous()

        B, N, _ = xyz.size()
        npoint = new_xyz.size(1)
        idx = torch.cuda.IntTensor(B, npoint, nsample).zero_()

        pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx)
        return idx

    @staticmethod
    def backward(ctx, a=None):
        return None, None, None, None


ball_query = BallQuery.apply


class QueryAndGroup(nn.Module):
    def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
        """
        :param radius: float, radius of ball
        :param nsample: int, maximum number of features to gather in the ball
        :param use_xyz:
        """
        super().__init__()
        self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz

    def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None) -> Tuple[torch.Tensor]:
        """
        :param xyz: (B, N, 3) xyz coordinates of the features
        :param new_xyz: (B, npoint, 3) centroids
        :param features: (B, C, N) descriptors of the features
        :return:
            new_features: (B, 3 + C, npoint, nsample)
        """
        idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
        xyz_trans = xyz.transpose(1, 2).contiguous()
        grouped_xyz = grouping_operation(xyz_trans, idx)  # (B, 3, npoint, nsample)
        grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)

        if features is not None:
            grouped_features = grouping_operation(features, idx)
            if self.use_xyz:
                new_features = torch.cat([grouped_xyz, grouped_features], dim=1)  # (B, C + 3, npoint, nsample)
            else:
                new_features = grouped_features
        else:
            assert self.use_xyz, "Cannot have features=None and use_xyz=False at the same time!"
            new_features = grouped_xyz

        return new_features


class GroupAll(nn.Module):
    def __init__(self, use_xyz: bool = True):
        super().__init__()
        self.use_xyz = use_xyz

    def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None):
        """
        :param xyz: (B, N, 3) xyz coordinates of the features
        :param new_xyz: ignored
        :param features: (B, C, N) descriptors of the features
        :return:
            new_features: (B, C + 3, 1, N)
        """
        grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
        if features is not None:
            grouped_features = features.unsqueeze(2)
            if self.use_xyz:
                new_features = torch.cat([grouped_xyz, grouped_features], dim=1)  # (B, 3 + C, 1, N)
            else:
                new_features = grouped_features
        else:
            new_features = grouped_xyz

        return new_features
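A hypothetical shape-check sketch for the autograd wrappers above (not part of the upload; it assumes a working CUDA build of pointnet2_cuda and that this file is importable as pointnet2_utils):

# Illustrative only: exercise the low-level ops and verify output shapes.
import torch
from pointnet2_utils import furthest_point_sample, gather_operation, ball_query, grouping_operation

xyz = torch.rand(2, 4096, 3).cuda()        # (B, N, 3)
features = torch.rand(2, 32, 4096).cuda()  # (B, C, N)

idx = furthest_point_sample(xyz, 512)                               # (B, 512) int32 indices
new_xyz = gather_operation(xyz.transpose(1, 2).contiguous(), idx)   # (B, 3, 512)
new_xyz = new_xyz.transpose(1, 2).contiguous()                      # (B, 512, 3)

group_idx = ball_query(0.2, 32, xyz, new_xyz)                       # (B, 512, 32)
grouped = grouping_operation(features, group_idx)                   # (B, 32, 512, 32)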
lib/pointnet2/pytorch_utils.py
ADDED
@@ -0,0 +1,236 @@
import torch.nn as nn
from typing import List, Tuple


class SharedMLP(nn.Sequential):

    def __init__(
            self,
            args: List[int],
            *,
            bn: bool = False,
            activation=nn.ReLU(inplace=True),
            preact: bool = False,
            first: bool = False,
            name: str = "",
            instance_norm: bool = False,
    ):
        super().__init__()

        for i in range(len(args) - 1):
            self.add_module(
                name + 'layer{}'.format(i),
                Conv2d(
                    args[i],
                    args[i + 1],
                    bn=(not first or not preact or (i != 0)) and bn,
                    activation=activation
                    if (not first or not preact or (i != 0)) else None,
                    preact=preact,
                    instance_norm=instance_norm
                )
            )


class _ConvBase(nn.Sequential):

    def __init__(
            self,
            in_size,
            out_size,
            kernel_size,
            stride,
            padding,
            activation,
            bn,
            init,
            conv=None,
            batch_norm=None,
            bias=True,
            preact=False,
            name="",
            instance_norm=False,
            instance_norm_func=None
    ):
        super().__init__()

        bias = bias and (not bn)
        conv_unit = conv(
            in_size,
            out_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias
        )
        init(conv_unit.weight)
        if bias:
            nn.init.constant_(conv_unit.bias, 0)

        if bn:
            if not preact:
                bn_unit = batch_norm(out_size)
            else:
                bn_unit = batch_norm(in_size)
        if instance_norm:
            if not preact:
                in_unit = instance_norm_func(out_size, affine=False, track_running_stats=False)
            else:
                in_unit = instance_norm_func(in_size, affine=False, track_running_stats=False)

        if preact:
            if bn:
                self.add_module(name + 'bn', bn_unit)

            if activation is not None:
                self.add_module(name + 'activation', activation)

            if not bn and instance_norm:
                self.add_module(name + 'in', in_unit)

        self.add_module(name + 'conv', conv_unit)

        if not preact:
            if bn:
                self.add_module(name + 'bn', bn_unit)

            if activation is not None:
                self.add_module(name + 'activation', activation)

            if not bn and instance_norm:
                self.add_module(name + 'in', in_unit)


class _BNBase(nn.Sequential):

    def __init__(self, in_size, batch_norm=None, name=""):
        super().__init__()
        self.add_module(name + "bn", batch_norm(in_size))

        nn.init.constant_(self[0].weight, 1.0)
        nn.init.constant_(self[0].bias, 0)


class BatchNorm1d(_BNBase):

    def __init__(self, in_size: int, *, name: str = ""):
        super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name)


class BatchNorm2d(_BNBase):

    def __init__(self, in_size: int, name: str = ""):
        super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name)


class Conv1d(_ConvBase):

    def __init__(
            self,
            in_size: int,
            out_size: int,
            *,
            kernel_size: int = 1,
            stride: int = 1,
            padding: int = 0,
            activation=nn.ReLU(inplace=True),
            bn: bool = False,
            init=nn.init.kaiming_normal_,
            bias: bool = True,
            preact: bool = False,
            name: str = "",
            instance_norm=False
    ):
        super().__init__(
            in_size,
            out_size,
            kernel_size,
            stride,
            padding,
            activation,
            bn,
            init,
            conv=nn.Conv1d,
            batch_norm=BatchNorm1d,
            bias=bias,
            preact=preact,
            name=name,
            instance_norm=instance_norm,
            instance_norm_func=nn.InstanceNorm1d
        )


class Conv2d(_ConvBase):

    def __init__(
            self,
            in_size: int,
            out_size: int,
            *,
            kernel_size: Tuple[int, int] = (1, 1),
            stride: Tuple[int, int] = (1, 1),
            padding: Tuple[int, int] = (0, 0),
            activation=nn.ReLU(inplace=True),
            bn: bool = False,
            init=nn.init.kaiming_normal_,
            bias: bool = True,
            preact: bool = False,
            name: str = "",
            instance_norm=False
    ):
        super().__init__(
            in_size,
            out_size,
            kernel_size,
            stride,
            padding,
            activation,
            bn,
            init,
            conv=nn.Conv2d,
            batch_norm=BatchNorm2d,
            bias=bias,
            preact=preact,
            name=name,
            instance_norm=instance_norm,
            instance_norm_func=nn.InstanceNorm2d
        )


class FC(nn.Sequential):

    def __init__(
            self,
            in_size: int,
            out_size: int,
            *,
            activation=nn.ReLU(inplace=True),
            bn: bool = False,
            init=None,
            preact: bool = False,
            name: str = ""
    ):
        super().__init__()

        fc = nn.Linear(in_size, out_size, bias=not bn)
        if init is not None:
            init(fc.weight)
        if not bn:
            nn.init.constant_(fc.bias, 0)

        if preact:
            if bn:
                self.add_module(name + 'bn', BatchNorm1d(in_size))

            if activation is not None:
                self.add_module(name + 'activation', activation)

        self.add_module(name + 'fc', fc)

        if not preact:
            if bn:
                self.add_module(name + 'bn', BatchNorm1d(out_size))

            if activation is not None:
                self.add_module(name + 'activation', activation)
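SharedMLP is the per-point MLP used throughout the set-abstraction modules: a stack of 1x1 Conv2d blocks applied over a (B, C, npoint, nsample) grouping tensor. A minimal CPU-side sketch (not part of the upload):

# Illustrative only: build a two-layer shared MLP and check the output shape.
import torch
from pytorch_utils import SharedMLP  # assumes lib/pointnet2 is on sys.path

mlp = SharedMLP([6, 64, 128], bn=True)
x = torch.rand(4, 6, 512, 32)   # (B, C_in, npoint, nsample)
y = mlp(x)                      # (B, 128, 512, 32): channels transformed, spatial dims untouched
print(y.shape)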
lib/pointnet2/setup.py
ADDED
@@ -0,0 +1,23 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='pointnet2',
    ext_modules=[
        CUDAExtension('pointnet2_cuda', [
            'src/pointnet2_api.cpp',

            'src/ball_query.cpp',
            'src/ball_query_gpu.cu',
            'src/group_points.cpp',
            'src/group_points_gpu.cu',
            'src/interpolate.cpp',
            'src/interpolate_gpu.cu',
            'src/sampling.cpp',
            'src/sampling_gpu.cu',
        ],
            extra_compile_args={'cxx': ['-g'],
                                'nvcc': ['-O2']})
    ],
    cmdclass={'build_ext': BuildExtension}
)
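This setup script builds the pointnet2_cuda extension the conventional way (run `python setup.py install` from lib/pointnet2 on a machine with nvcc). As a hedged alternative sketch, the same sources can also be JIT-compiled at import time with PyTorch's cpp_extension loader; paths and flags below mirror the setup script and are illustrative:

# Illustrative alternative to setup.py: JIT-compile and load the extension
# (assumes the working directory is lib/pointnet2 and CUDA toolchain is available).
from torch.utils.cpp_extension import load

pointnet2_cuda = load(
    name="pointnet2_cuda",
    sources=[
        "src/pointnet2_api.cpp",
        "src/ball_query.cpp", "src/ball_query_gpu.cu",
        "src/group_points.cpp", "src/group_points_gpu.cu",
        "src/interpolate.cpp", "src/interpolate_gpu.cu",
        "src/sampling.cpp", "src/sampling_gpu.cu",
    ],
    extra_cflags=["-g"],
    extra_cuda_cflags=["-O2"],
)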
lib/pointnet2/src/ball_query.cpp
ADDED
@@ -0,0 +1,24 @@
#include <torch/serialize/tensor.h>
#include <vector>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include "ball_query_gpu.h"

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)

int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample,
    at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor) {
    CHECK_INPUT(new_xyz_tensor);
    CHECK_INPUT(xyz_tensor);
    const float *new_xyz = new_xyz_tensor.data<float>();
    const float *xyz = xyz_tensor.data<float>();
    int *idx = idx_tensor.data<int>();

    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
    ball_query_kernel_launcher_fast(b, n, m, radius, nsample, new_xyz, xyz, idx, stream);
    return 1;
}
lib/pointnet2/src/ball_query_gpu.cu
ADDED
@@ -0,0 +1,67 @@
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "ball_query_gpu.h"
#include "cuda_utils.h"


__global__ void ball_query_kernel_fast(int b, int n, int m, float radius, int nsample,
    const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx) {
    // new_xyz: (B, M, 3)
    // xyz: (B, N, 3)
    // output:
    //      idx: (B, M, nsample)
    int bs_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (bs_idx >= b || pt_idx >= m) return;

    new_xyz += bs_idx * m * 3 + pt_idx * 3;
    xyz += bs_idx * n * 3;
    idx += bs_idx * m * nsample + pt_idx * nsample;

    float radius2 = radius * radius;
    float new_x = new_xyz[0];
    float new_y = new_xyz[1];
    float new_z = new_xyz[2];

    int cnt = 0;
    for (int k = 0; k < n; ++k) {
        float x = xyz[k * 3 + 0];
        float y = xyz[k * 3 + 1];
        float z = xyz[k * 3 + 2];
        float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z);
        if (d2 < radius2){
            if (cnt == 0){
                for (int l = 0; l < nsample; ++l) {
                    idx[l] = k;
                }
            }
            idx[cnt] = k;
            ++cnt;
            if (cnt >= nsample) break;
        }
    }
}


void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, \
    const float *new_xyz, const float *xyz, int *idx, cudaStream_t stream) {
    // new_xyz: (B, M, 3)
    // xyz: (B, N, 3)
    // output:
    //      idx: (B, M, nsample)

    cudaError_t err;

    dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    ball_query_kernel_fast<<<blocks, threads, 0, stream>>>(b, n, m, radius, nsample, new_xyz, xyz, idx);
    // cudaDeviceSynchronize();  // for using printf in kernel function
    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}
lib/pointnet2/src/ball_query_gpu.h
ADDED
@@ -0,0 +1,15 @@
#ifndef _BALL_QUERY_GPU_H
#define _BALL_QUERY_GPU_H

#include <torch/serialize/tensor.h>
#include <vector>
#include <cuda.h>
#include <cuda_runtime_api.h>

int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample,
    at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor);

void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample,
    const float *xyz, const float *new_xyz, int *idx, cudaStream_t stream);

#endif
lib/pointnet2/src/cuda_utils.h
ADDED
@@ -0,0 +1,15 @@
#ifndef _CUDA_UTILS_H
#define _CUDA_UTILS_H

#include <cmath>

#define TOTAL_THREADS 1024
#define THREADS_PER_BLOCK 256
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))

inline int opt_n_threads(int work_size) {
    const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);

    return max(min(1 << pow_2, TOTAL_THREADS), 1);
}
#endif
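These two helpers determine the launch configuration used by every kernel in this directory: DIVUP gives the number of blocks needed to cover a workload at THREADS_PER_BLOCK threads per block, and opt_n_threads picks the largest power-of-two thread count (capped at 1024) for the per-batch furthest-point-sampling kernel. A small Python mirror, purely for illustration:

# Illustrative Python mirror of the launch-configuration helpers in cuda_utils.h.
import math

TOTAL_THREADS = 1024
THREADS_PER_BLOCK = 256

def divup(m, n):
    # smallest block count such that blocks * n >= m
    return m // n + (m % n > 0)

def opt_n_threads(work_size):
    # largest power of two not exceeding work_size, clamped to [1, TOTAL_THREADS]
    pow_2 = int(math.log(work_size) / math.log(2.0))
    return max(min(1 << pow_2, TOTAL_THREADS), 1)

assert divup(1000, THREADS_PER_BLOCK) == 4
assert opt_n_threads(4096) == 1024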
lib/pointnet2/src/group_points.cpp
ADDED
@@ -0,0 +1,34 @@
#include <torch/serialize/tensor.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <vector>
#include "group_points_gpu.h"
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>


int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample,
    at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) {

    float *grad_points = grad_points_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();
    const float *grad_out = grad_out_tensor.data<float>();

    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
    group_points_grad_kernel_launcher_fast(b, c, n, npoints, nsample, grad_out, idx, grad_points, stream);
    return 1;
}


int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample,
    at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor) {

    const float *points = points_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();
    float *out = out_tensor.data<float>();

    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
    group_points_kernel_launcher_fast(b, c, n, npoints, nsample, points, idx, out, stream);
    return 1;
}
lib/pointnet2/src/group_points_gpu.cu
ADDED
@@ -0,0 +1,86 @@
#include <stdio.h>
#include <stdlib.h>

#include "cuda_utils.h"
#include "group_points_gpu.h"


__global__ void group_points_grad_kernel_fast(int b, int c, int n, int npoints, int nsample,
    const float *__restrict__ grad_out, const int *__restrict__ idx, float *__restrict__ grad_points) {
    // grad_out: (B, C, npoints, nsample)
    // idx: (B, npoints, nsample)
    // output:
    //      grad_points: (B, C, N)
    int bs_idx = blockIdx.z;
    int c_idx = blockIdx.y;
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int pt_idx = index / nsample;
    if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;

    int sample_idx = index % nsample;
    grad_out += bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx;
    idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx;

    atomicAdd(grad_points + bs_idx * c * n + c_idx * n + idx[0], grad_out[0]);
}

void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
    const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) {
    // grad_out: (B, C, npoints, nsample)
    // idx: (B, npoints, nsample)
    // output:
    //      grad_points: (B, C, N)
    cudaError_t err;
    dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    group_points_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, nsample, grad_out, idx, grad_points);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}


__global__ void group_points_kernel_fast(int b, int c, int n, int npoints, int nsample,
    const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) {
    // points: (B, C, N)
    // idx: (B, npoints, nsample)
    // output:
    //      out: (B, C, npoints, nsample)
    int bs_idx = blockIdx.z;
    int c_idx = blockIdx.y;
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int pt_idx = index / nsample;
    if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;

    int sample_idx = index % nsample;

    idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx;
    int in_idx = bs_idx * c * n + c_idx * n + idx[0];
    int out_idx = bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx;

    out[out_idx] = points[in_idx];
}


void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
    const float *points, const int *idx, float *out, cudaStream_t stream) {
    // points: (B, C, N)
    // idx: (B, npoints, nsample)
    // output:
    //      out: (B, C, npoints, nsample)
    cudaError_t err;
    dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    group_points_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, nsample, points, idx, out);
    // cudaDeviceSynchronize();  // for using printf in kernel function
    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}
lib/pointnet2/src/group_points_gpu.h
ADDED
@@ -0,0 +1,22 @@
#ifndef _GROUP_POINTS_GPU_H
#define _GROUP_POINTS_GPU_H

#include <torch/serialize/tensor.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <vector>


int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample,
    at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor);

void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
    const float *points, const int *idx, float *out, cudaStream_t stream);

int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample,
    at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor);

void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
    const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream);

#endif
lib/pointnet2/src/interpolate.cpp
ADDED
@@ -0,0 +1,53 @@
#include <torch/serialize/tensor.h>
#include <vector>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include "interpolate_gpu.h"
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>


void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor,
    at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor) {
    const float *unknown = unknown_tensor.data<float>();
    const float *known = known_tensor.data<float>();
    float *dist2 = dist2_tensor.data<float>();
    int *idx = idx_tensor.data<int>();

    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
    three_nn_kernel_launcher_fast(b, n, m, unknown, known, dist2, idx, stream);
}


void three_interpolate_wrapper_fast(int b, int c, int m, int n,
                                    at::Tensor points_tensor,
                                    at::Tensor idx_tensor,
                                    at::Tensor weight_tensor,
                                    at::Tensor out_tensor) {

    const float *points = points_tensor.data<float>();
    const float *weight = weight_tensor.data<float>();
    float *out = out_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();

    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
    three_interpolate_kernel_launcher_fast(b, c, m, n, points, idx, weight, out, stream);
}

void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m,
                                         at::Tensor grad_out_tensor,
                                         at::Tensor idx_tensor,
                                         at::Tensor weight_tensor,
                                         at::Tensor grad_points_tensor) {

    const float *grad_out = grad_out_tensor.data<float>();
    const float *weight = weight_tensor.data<float>();
    float *grad_points = grad_points_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();

    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
    three_interpolate_grad_kernel_launcher_fast(b, c, n, m, grad_out, idx, weight, grad_points, stream);
}
lib/pointnet2/src/interpolate_gpu.cu
ADDED
@@ -0,0 +1,161 @@
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "cuda_utils.h"
#include "interpolate_gpu.h"


__global__ void three_nn_kernel_fast(int b, int n, int m, const float *__restrict__ unknown,
    const float *__restrict__ known, float *__restrict__ dist2, int *__restrict__ idx) {
    // unknown: (B, N, 3)
    // known: (B, M, 3)
    // output:
    //      dist2: (B, N, 3)
    //      idx: (B, N, 3)

    int bs_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (bs_idx >= b || pt_idx >= n) return;

    unknown += bs_idx * n * 3 + pt_idx * 3;
    known += bs_idx * m * 3;
    dist2 += bs_idx * n * 3 + pt_idx * 3;
    idx += bs_idx * n * 3 + pt_idx * 3;

    float ux = unknown[0];
    float uy = unknown[1];
    float uz = unknown[2];

    double best1 = 1e40, best2 = 1e40, best3 = 1e40;
    int besti1 = 0, besti2 = 0, besti3 = 0;
    for (int k = 0; k < m; ++k) {
        float x = known[k * 3 + 0];
        float y = known[k * 3 + 1];
        float z = known[k * 3 + 2];
        float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
        if (d < best1) {
            best3 = best2; besti3 = besti2;
            best2 = best1; besti2 = besti1;
            best1 = d; besti1 = k;
        }
        else if (d < best2) {
            best3 = best2; besti3 = besti2;
            best2 = d; besti2 = k;
        }
        else if (d < best3) {
            best3 = d; besti3 = k;
        }
    }
    dist2[0] = best1; dist2[1] = best2; dist2[2] = best3;
    idx[0] = besti1; idx[1] = besti2; idx[2] = besti3;
}


void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown,
    const float *known, float *dist2, int *idx, cudaStream_t stream) {
    // unknown: (B, N, 3)
    // known: (B, M, 3)
    // output:
    //      dist2: (B, N, 3)
    //      idx: (B, N, 3)

    cudaError_t err;
    dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    three_nn_kernel_fast<<<blocks, threads, 0, stream>>>(b, n, m, unknown, known, dist2, idx);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}


__global__ void three_interpolate_kernel_fast(int b, int c, int m, int n, const float *__restrict__ points,
    const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ out) {
    // points: (B, C, M)
    // idx: (B, N, 3)
    // weight: (B, N, 3)
    // output:
    //      out: (B, C, N)

    int bs_idx = blockIdx.z;
    int c_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;

    weight += bs_idx * n * 3 + pt_idx * 3;
    points += bs_idx * c * m + c_idx * m;
    idx += bs_idx * n * 3 + pt_idx * 3;
    out += bs_idx * c * n + c_idx * n;

    out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] + weight[2] * points[idx[2]];
}

void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n,
    const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream) {
    // points: (B, C, M)
    // idx: (B, N, 3)
    // weight: (B, N, 3)
    // output:
    //      out: (B, C, N)

    cudaError_t err;
    dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);
    three_interpolate_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, m, n, points, idx, weight, out);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}


__global__ void three_interpolate_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out,
    const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ grad_points) {
    // grad_out: (B, C, N)
    // weight: (B, N, 3)
    // output:
    //      grad_points: (B, C, M)

    int bs_idx = blockIdx.z;
    int c_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;

    grad_out += bs_idx * c * n + c_idx * n + pt_idx;
    weight += bs_idx * n * 3 + pt_idx * 3;
    grad_points += bs_idx * c * m + c_idx * m;
    idx += bs_idx * n * 3 + pt_idx * 3;

    atomicAdd(grad_points + idx[0], grad_out[0] * weight[0]);
    atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]);
    atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]);
}

void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out,
    const int *idx, const float *weight, float *grad_points, cudaStream_t stream) {
    // grad_out: (B, C, N)
    // weight: (B, N, 3)
    // output:
    //      grad_points: (B, C, M)

    cudaError_t err;
    dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);
    three_interpolate_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, m, grad_out, idx, weight, grad_points);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}
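Together these kernels implement three-nearest-neighbour, inverse-distance-weighted feature interpolation (the op consumed by PointnetFPModule, which splits it into three_nn, the 1/(d+1e-8) weighting, and three_interpolate). A pure-PyTorch reference of the combined computation, for illustration only; the CUDA path avoids materialising the full (B, n, m) distance matrix:

# Illustrative dense reference of three-NN inverse-distance interpolation.
import torch

def three_interpolate_reference(unknown, known, known_feats):
    # unknown: (B, n, 3), known: (B, m, 3), known_feats: (B, C, m) -> (B, C, n)
    dist = torch.cdist(unknown, known)                    # (B, n, m) Euclidean distances
    dist, idx = dist.topk(3, dim=-1, largest=False)       # three nearest neighbours per unknown point
    weight = 1.0 / (dist + 1e-8)
    weight = weight / weight.sum(dim=-1, keepdim=True)    # (B, n, 3) normalised inverse-distance weights
    C = known_feats.size(1)
    neighbours = torch.gather(
        known_feats.unsqueeze(2).expand(-1, -1, idx.size(1), -1),      # (B, C, n, m)
        3, idx.unsqueeze(1).expand(-1, C, -1, -1))                     # (B, C, n, 3)
    return (neighbours * weight.unsqueeze(1)).sum(dim=-1)              # (B, C, n)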
lib/pointnet2/src/interpolate_gpu.h
ADDED
@@ -0,0 +1,30 @@
#ifndef _INTERPOLATE_GPU_H
#define _INTERPOLATE_GPU_H

#include <torch/serialize/tensor.h>
#include <vector>
#include <cuda.h>
#include <cuda_runtime_api.h>


void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor,
    at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor);

void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown,
    const float *known, float *dist2, int *idx, cudaStream_t stream);


void three_interpolate_wrapper_fast(int b, int c, int m, int n, at::Tensor points_tensor,
    at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor out_tensor);

void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n,
    const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream);


void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m, at::Tensor grad_out_tensor,
    at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor grad_points_tensor);

void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out,
    const int *idx, const float *weight, float *grad_points, cudaStream_t stream);

#endif
lib/pointnet2/src/pointnet2_api.cpp
ADDED
@@ -0,0 +1,24 @@
#include <torch/serialize/tensor.h>
#include <torch/extension.h>

#include "ball_query_gpu.h"
#include "group_points_gpu.h"
#include "sampling_gpu.h"
#include "interpolate_gpu.h"


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("ball_query_wrapper", &ball_query_wrapper_fast, "ball_query_wrapper_fast");

    m.def("group_points_wrapper", &group_points_wrapper_fast, "group_points_wrapper_fast");
    m.def("group_points_grad_wrapper", &group_points_grad_wrapper_fast, "group_points_grad_wrapper_fast");

    m.def("gather_points_wrapper", &gather_points_wrapper_fast, "gather_points_wrapper_fast");
    m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper_fast, "gather_points_grad_wrapper_fast");

    m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper");

    m.def("three_nn_wrapper", &three_nn_wrapper_fast, "three_nn_wrapper_fast");
    m.def("three_interpolate_wrapper", &three_interpolate_wrapper_fast, "three_interpolate_wrapper_fast");
    m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_fast, "three_interpolate_grad_wrapper_fast");
}
lib/pointnet2/src/sampling.cpp
ADDED
@@ -0,0 +1,45 @@
#include <torch/serialize/tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <vector>
#include <ATen/cuda/CUDAEvent.h>
#include "sampling_gpu.h"


int gather_points_wrapper_fast(int b, int c, int n, int npoints,
    at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor){
    const float *points = points_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();
    float *out = out_tensor.data<float>();

    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
    gather_points_kernel_launcher_fast(b, c, n, npoints, points, idx, out, stream);
    return 1;
}


int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints,
    at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) {

    const float *grad_out = grad_out_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();
    float *grad_points = grad_points_tensor.data<float>();

    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
    gather_points_grad_kernel_launcher_fast(b, c, n, npoints, grad_out, idx, grad_points, stream);
    return 1;
}


int furthest_point_sampling_wrapper(int b, int n, int m,
    at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) {

    const float *points = points_tensor.data<float>();
    float *temp = temp_tensor.data<float>();
    int *idx = idx_tensor.data<int>();

    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
    furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx, stream);
    return 1;
}
lib/pointnet2/src/sampling_gpu.cu
ADDED
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <stdio.h>
|
2 |
+
#include <stdlib.h>
|
3 |
+
|
4 |
+
#include "cuda_utils.h"
|
5 |
+
#include "sampling_gpu.h"
|
6 |
+
|
7 |
+
|
8 |
+
__global__ void gather_points_kernel_fast(int b, int c, int n, int m,
|
9 |
+
const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) {
|
10 |
+
// points: (B, C, N)
|
11 |
+
// idx: (B, M)
|
12 |
+
// output:
|
13 |
+
// out: (B, C, M)
|
14 |
+
|
15 |
+
int bs_idx = blockIdx.z;
|
16 |
+
int c_idx = blockIdx.y;
|
17 |
+
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
18 |
+
if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;
|
19 |
+
|
20 |
+
out += bs_idx * c * m + c_idx * m + pt_idx;
|
21 |
+
idx += bs_idx * m + pt_idx;
|
22 |
+
points += bs_idx * c * n + c_idx * n;
|
23 |
+
out[0] = points[idx[0]];
|
24 |
+
}
|
25 |
+
|
26 |
+
void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints,
|
27 |
+
const float *points, const int *idx, float *out, cudaStream_t stream) {
|
28 |
+
// points: (B, C, N)
|
29 |
+
// idx: (B, npoints)
|
30 |
+
// output:
|
31 |
+
// out: (B, C, npoints)
|
32 |
+
|
33 |
+
cudaError_t err;
|
34 |
+
dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
|
35 |
+
dim3 threads(THREADS_PER_BLOCK);
|
36 |
+
|
37 |
+
gather_points_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, points, idx, out);
|
38 |
+
|
39 |
+
err = cudaGetLastError();
|
40 |
+
if (cudaSuccess != err) {
|
41 |
+
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
|
42 |
+
exit(-1);
|
43 |
+
}
|
44 |
+
}
|
45 |
+
|
46 |
+
__global__ void gather_points_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out,
|
47 |
+
const int *__restrict__ idx, float *__restrict__ grad_points) {
|
48 |
+
// grad_out: (B, C, M)
|
49 |
+
// idx: (B, M)
|
50 |
+
// output:
|
51 |
+
// grad_points: (B, C, N)
|
52 |
+
|
53 |
+
int bs_idx = blockIdx.z;
|
54 |
+
int c_idx = blockIdx.y;
|
55 |
+
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
56 |
+
if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;
|
57 |
+
|
58 |
+
grad_out += bs_idx * c * m + c_idx * m + pt_idx;
|
59 |
+
idx += bs_idx * m + pt_idx;
|
60 |
+
grad_points += bs_idx * c * n + c_idx * n;
|
61 |
+
|
62 |
+
atomicAdd(grad_points + idx[0], grad_out[0]);
|
63 |
+
}
|
64 |
+
|
65 |
+
void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints,
|
66 |
+
const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) {
|
67 |
+
// grad_out: (B, C, npoints)
|
68 |
+
// idx: (B, npoints)
|
69 |
+
// output:
|
70 |
+
// grad_points: (B, C, N)
|
71 |
+
|
72 |
+
cudaError_t err;
|
73 |
+
dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
|
74 |
+
dim3 threads(THREADS_PER_BLOCK);
|
75 |
+
|
76 |
+
gather_points_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, grad_out, idx, grad_points);
|
77 |
+
|
78 |
+
err = cudaGetLastError();
|
79 |
+
if (cudaSuccess != err) {
|
80 |
+
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
|
81 |
+
exit(-1);
|
82 |
+
}
|
83 |
+
}
|
84 |
+
|
85 |
+
|
86 |
+
__device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, int idx1, int idx2){
|
87 |
+
const float v1 = dists[idx1], v2 = dists[idx2];
|
88 |
+
const int i1 = dists_i[idx1], i2 = dists_i[idx2];
|
89 |
+
dists[idx1] = max(v1, v2);
|
90 |
+
dists_i[idx1] = v2 > v1 ? i2 : i1;
|
91 |
+
}
|
92 |
+
|
93 |
+
template <unsigned int block_size>
|
94 |
+
__global__ void furthest_point_sampling_kernel(int b, int n, int m,
|
95 |
+
const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) {
|
96 |
+
// dataset: (B, N, 3)
|
97 |
+
// tmp: (B, N)
|
98 |
+
// output:
|
99 |
+
// idx: (B, M)
|
100 |
+
|
101 |
+
if (m <= 0) return;
|
102 |
+
__shared__ float dists[block_size];
|
103 |
+
__shared__ int dists_i[block_size];
|
104 |
+
|
105 |
+
int batch_index = blockIdx.x;
|
106 |
+
dataset += batch_index * n * 3;
|
107 |
+
temp += batch_index * n;
|
108 |
+
idxs += batch_index * m;
|
109 |
+
|
110 |
+
int tid = threadIdx.x;
|
111 |
+
const int stride = block_size;
|
112 |
+
|
113 |
+
int old = 0;
|
114 |
+
if (threadIdx.x == 0)
|
115 |
+
idxs[0] = old;
|
116 |
+
|
117 |
+
    __syncthreads();

    for (int j = 1; j < m; j++) {
        int besti = 0;
        float best = -1;
        float x1 = dataset[old * 3 + 0];
        float y1 = dataset[old * 3 + 1];
        float z1 = dataset[old * 3 + 2];
        for (int k = tid; k < n; k += stride) {
            float x2, y2, z2;
            x2 = dataset[k * 3 + 0];
            y2 = dataset[k * 3 + 1];
            z2 = dataset[k * 3 + 2];
            // float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
            // if (mag <= 1e-3)
            //     continue;

            float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
            float d2 = min(d, temp[k]);
            temp[k] = d2;
            besti = d2 > best ? k : besti;
            best = d2 > best ? d2 : best;
        }
        dists[tid] = best;
        dists_i[tid] = besti;
        __syncthreads();

        if (block_size >= 1024) {
            if (tid < 512) {
                __update(dists, dists_i, tid, tid + 512);
            }
            __syncthreads();
        }

        if (block_size >= 512) {
            if (tid < 256) {
                __update(dists, dists_i, tid, tid + 256);
            }
            __syncthreads();
        }
        if (block_size >= 256) {
            if (tid < 128) {
                __update(dists, dists_i, tid, tid + 128);
            }
            __syncthreads();
        }
        if (block_size >= 128) {
            if (tid < 64) {
                __update(dists, dists_i, tid, tid + 64);
            }
            __syncthreads();
        }
        if (block_size >= 64) {
            if (tid < 32) {
                __update(dists, dists_i, tid, tid + 32);
            }
            __syncthreads();
        }
        if (block_size >= 32) {
            if (tid < 16) {
                __update(dists, dists_i, tid, tid + 16);
            }
            __syncthreads();
        }
        if (block_size >= 16) {
            if (tid < 8) {
                __update(dists, dists_i, tid, tid + 8);
            }
            __syncthreads();
        }
        if (block_size >= 8) {
            if (tid < 4) {
                __update(dists, dists_i, tid, tid + 4);
            }
            __syncthreads();
        }
        if (block_size >= 4) {
            if (tid < 2) {
                __update(dists, dists_i, tid, tid + 2);
            }
            __syncthreads();
        }
        if (block_size >= 2) {
            if (tid < 1) {
                __update(dists, dists_i, tid, tid + 1);
            }
            __syncthreads();
        }

        old = dists_i[0];
        if (tid == 0)
            idxs[j] = old;
    }
}

void furthest_point_sampling_kernel_launcher(int b, int n, int m,
    const float *dataset, float *temp, int *idxs, cudaStream_t stream) {
    // dataset: (B, N, 3)
    // tmp: (B, N)
    // output:
    //      idx: (B, M)

    cudaError_t err;
    unsigned int n_threads = opt_n_threads(n);

    switch (n_threads) {
        case 1024:
            furthest_point_sampling_kernel<1024><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 512:
            furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 256:
            furthest_point_sampling_kernel<256><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 128:
            furthest_point_sampling_kernel<128><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 64:
            furthest_point_sampling_kernel<64><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 32:
            furthest_point_sampling_kernel<32><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 16:
            furthest_point_sampling_kernel<16><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 8:
            furthest_point_sampling_kernel<8><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 4:
            furthest_point_sampling_kernel<4><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 2:
            furthest_point_sampling_kernel<2><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 1:
            furthest_point_sampling_kernel<1><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        default:
            furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
    }

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}
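For orientation, below is a minimal sketch of how this launcher is typically reached from Python through the wrapper declared in the header that follows. It is illustrative only: the compiled extension name `pointnet2_cuda` and the helper `furthest_point_sample` are assumptions, not part of this upload.

# Illustrative sketch (not part of the upload): driving furthest point sampling from Python.
import torch
import pointnet2_cuda  # assumed name of the compiled extension; adjust to the name registered in setup.py

def furthest_point_sample(xyz: torch.Tensor, npoint: int) -> torch.Tensor:
    # xyz: (B, N, 3) contiguous CUDA float tensor -> (B, npoint) int32 indices
    assert xyz.is_cuda and xyz.is_contiguous()
    B, N, _ = xyz.shape
    idx = torch.zeros(B, npoint, dtype=torch.int32, device=xyz.device)
    # running squared distance to the already-selected set, initialised to "infinity"
    temp = torch.full((B, N), 1e10, dtype=torch.float32, device=xyz.device)
    pointnet2_cuda.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, idx)
    return idx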
lib/pointnet2/src/sampling_gpu.h
ADDED
@@ -0,0 +1,29 @@
#ifndef _SAMPLING_GPU_H
#define _SAMPLING_GPU_H

#include <torch/serialize/tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <vector>


int gather_points_wrapper_fast(int b, int c, int n, int npoints,
    at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor);

void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints,
    const float *points, const int *idx, float *out, cudaStream_t stream);


int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints,
    at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor);

void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints,
    const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream);


int furthest_point_sampling_wrapper(int b, int n, int m,
    at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor);

void furthest_point_sampling_kernel_launcher(int b, int n, int m,
    const float *dataset, float *temp, int *idxs, cudaStream_t stream);

#endif
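The gather declarations above pair the indices produced by furthest point sampling with per-point features. As a reference for what the CUDA path is expected to compute (an assumption based on the declared shapes, not code from this upload), a pure-PyTorch equivalent would be:

# Reference semantics sketch: out[b, c, m] = points[b, c, idx[b, m]]
import torch

def gather_points_reference(points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    # points: (B, C, N), idx: (B, M) -> out: (B, C, M)
    B, C, _ = points.shape
    M = idx.shape[1]
    index = idx.long().unsqueeze(1).expand(B, C, M)
    return torch.gather(points, 2, index)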
model/LLM/__init__.py
ADDED
@@ -0,0 +1 @@
from . import onellm
model/LLM/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (189 Bytes).
model/LLM/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (187 Bytes).
model/LLM/__pycache__/onellm.cpython-310.pyc
ADDED
Binary file (14 kB).
model/LLM/__pycache__/onellm.cpython-39.pyc
ADDED
Binary file (13.9 kB).
model/LLM/onellm.py
ADDED
@@ -0,0 +1,495 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

from typing import Optional, Tuple
from dataclasses import dataclass
import math
import functools
import copy

import torch
from torch import nn
import torch.nn.functional as F

import fairscale.nn.model_parallel.initialize as fs_init
from fairscale.nn.model_parallel.layers import (
    ParallelEmbedding,
    RowParallelLinear,
    ColumnParallelLinear,
)
from ..components import RMSNorm
from flash_attn import flash_attn_func

import open_clip


default_linear_init = nn.init.xavier_uniform_


@dataclass
class ModelArgs:
    dim: int = 512
    n_layers: int = 8
    n_heads: int = 8
    vocab_size: int = -1  # defined later by tokenizer
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    norm_eps: float = 1e-5

    max_batch_size: int = 32
    max_seq_len: int = 2048


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)
                             [: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim -
             1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()
        self.head_dim = args.dim // args.n_heads

        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=default_linear_init,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=default_linear_init,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=default_linear_init,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=default_linear_init,
        )

        self.flash = True
        self.k_cache, self.v_cache = None, None

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], prompt=None):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        if freqs_cis is not None:
            xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        if self.k_cache is None or self.v_cache is None:
            keys, values = xk, xv
        else:
            self.k_cache = self.k_cache.to(xk)
            self.v_cache = self.v_cache.to(xv)
            self.k_cache[:bsz, start_pos: start_pos + seqlen, :, :] = xk
            self.v_cache[:bsz, start_pos: start_pos + seqlen, :, :] = xv
            keys = self.k_cache[:bsz, :start_pos + seqlen]
            values = self.v_cache[:bsz, :start_pos + seqlen]

        output = flash_attn_func(
            xq, keys, values, dropout_p=0.0, causal=mask is not None)
        output = output.contiguous().view(bsz, seqlen, -1)

        return self.wo(output)

    def allocate_kv_cache(self, max_batch_size: int, max_seq_len: int) -> None:
        kv_cache_shape = (max_batch_size, max_seq_len,
                          self.n_local_heads, self.head_dim)
        if self.k_cache is None or self.k_cache.size() != kv_cache_shape:
            self.k_cache = torch.empty(kv_cache_shape)
        if self.v_cache is None or self.v_cache.size() != kv_cache_shape:
            self.v_cache = torch.empty(kv_cache_shape)

    def destroy_kv_cache(self) -> None:
        self.k_cache, self.v_cache = None, None


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * \
            ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=default_linear_init,
        )
        self.w2 = RowParallelLinear(
            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=default_linear_init
        )
        self.w3 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=default_linear_init
        )

    def _silu_gating(self, x, y):
        return F.silu(x) * y

    def forward(self, x):
        return self.w2(self._silu_gating(self.w1(x), self.w3(x)))


class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def _forward_ffn(self, h):
        return h + self.feed_forward(self.ffn_norm(h))

    def _forward_attention(self, x, start_pos, freqs_cis, mask, prompt):
        return x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask, prompt)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], prompt=None):
        h = self._forward_attention(x, start_pos, freqs_cis, mask, prompt)
        out = self._forward_ffn(h)
        return out


class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x


class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers
        self.tok_embeddings = ParallelEmbedding(
            params.vocab_size, params.dim, init_method=nn.init.normal_,
        )

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = ColumnParallelLinear(
            params.dim, params.vocab_size, bias=False, init_method=default_linear_init,
        )

        self.freqs_cis = precompute_freqs_cis(
            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
        )

        # load clip
        self.clip, _, _ = open_clip.create_model_and_transforms(
            'ViT-L-14', pretrained='openai')
        for param in self.clip.parameters():
            param.requires_grad = False
            param.data = param.data.half()
        self.clip.transformer = None

        self.image_words = 30
        self.cache_image_words = 0  # for inference

        clip_width = self.clip.visual.conv1.out_channels
        # create modal shared modules
        self.resample_layers = nn.ModuleDict()
        self.num_experts = 3
        self.num_resample_layers = 8
        for expert in range(self.num_experts):
            expert = str(expert)
            self.resample_layers[expert] = nn.ModuleList()
            resampler_params = copy.deepcopy(params)
            resampler_params.n_heads = 16
            resampler_params.dim = clip_width
            for layer_id in range(self.num_resample_layers):
                self.resample_layers[expert].append(
                    TransformerBlock(layer_id, resampler_params))

        self.conv1 = nn.ModuleDict()
        self.positional_embedding = nn.ParameterDict()
        self.resample_tokens = nn.ParameterDict()
        self.clip_proj1 = nn.ModuleDict()
        self.clip_proj2 = nn.ModuleDict()
        self.routers = nn.ModuleDict()
        self.start_tag = nn.ParameterDict()
        self.end_tag = nn.ParameterDict()
        # self.modals = ['image', 'audio', 'point', 'video', 'rgbd', 'rgbn', 'fmri', 'imu']
        self.modals = ['image', 'audio', 'video', 'rgbd', 'rgbn', 'fmri', 'imu']
        for modal in self.modals:
            if modal in ['image', 'video', 'rgbd', 'rgbn']:
                # these modalities reuse CLIP's patch embedding and positional embedding
                modal_tokens = 256 + 1
                pass
            elif modal == 'audio':
                self.conv1[modal] = nn.Conv2d(
                    1, clip_width, kernel_size=(16, 16), stride=(10, 10))
                modal_tokens = 1212 + 1
                self.positional_embedding[modal] = nn.Parameter(
                    torch.empty([modal_tokens, clip_width]))
                nn.init.normal_(self.positional_embedding[modal], std=0.02)
            elif modal == 'point':
                from lib.point_utils import PointPatchEmbed
                self.conv1[modal] = PointPatchEmbed(
                    in_channels=6, channels=clip_width)
                modal_tokens = 1024 + 1
                self.positional_embedding[modal] = nn.Parameter(
                    torch.empty([modal_tokens, clip_width]))
                nn.init.normal_(self.positional_embedding[modal], std=0.02)
            elif modal == 'fmri':
                self.conv1[modal] = nn.Linear(15724, 8192)
                self.positional_embedding[modal] = nn.Parameter(
                    torch.empty([8 + 1, clip_width]))
                nn.init.normal_(self.positional_embedding[modal], std=0.02)
            elif modal == 'imu':
                self.conv1[modal] = nn.Conv1d(
                    in_channels=6, out_channels=clip_width, kernel_size=10, bias=False)
                self.positional_embedding[modal] = nn.Parameter(
                    torch.empty([391 + 1, clip_width]))
                nn.init.normal_(self.positional_embedding[modal], std=0.02)

            self.routers[modal] = Mlp(
                clip_width, clip_width * 4, self.num_experts)

            self.resample_tokens[modal] = nn.Parameter(
                torch.empty([1, 30, resampler_params.dim]))
            nn.init.normal_(self.resample_tokens[modal], std=0.02)

            self.clip_proj1[modal] = nn.Sequential(
                nn.Linear(clip_width, resampler_params.dim),
                nn.LayerNorm(resampler_params.dim))

            self.clip_proj2[modal] = nn.Sequential(
                nn.Linear(resampler_params.dim, params.dim),
                nn.LayerNorm(params.dim))

            self.start_tag[modal] = nn.Parameter(torch.rand(1, 1, params.dim))
            self.end_tag[modal] = nn.Parameter(torch.rand(1, 1, params.dim))

    # @torch.no_grad()
    def clip_encode_image(self, x, modal='image'):
        # shape = [*, width, grid ** 2]
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

        x = torch.cat([self.clip.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1,
                       x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]

        # use the pretrained positional embedding for the remaining modalities
        pos_embedding = self.clip.visual.positional_embedding
        if modal in ['audio', 'point', 'fmri', 'imu']:
            pos_embedding = self.positional_embedding[modal]

        x = x + pos_embedding.to(x.dtype)
        x = self.clip.visual.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip.visual.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        # preserve all spatial tokens
        x = self.clip.visual.ln_post(x[:, :, :])

        # if self.clip.visual.proj is not None:
        #     x = x @ self.clip.visual.proj

        return x

    def encode_image(self, x, modal='image'):
        bsz = x.size(0)
        T = 1
        if modal in ['image']:
            # modified from CLIP
            x = self.clip.visual.conv1(x)  # shape = [*, width, grid, grid]
        elif modal in ['audio', 'imu']:
            x = self.conv1[modal](x)
        elif modal == 'point':
            # [B, 16384, 6] -> [B, 1024, 1024, 1]
            x = self.conv1[modal](x.float()).to(x.dtype)
        elif modal in ['video', 'rgbd', 'rgbn']:
            # [B, 15, 3, 224, 224]
            B, T = x.shape[:2]
            bsz = B * T
            x = x.reshape(bsz, *x.shape[2:])
            x = self.clip.visual.conv1(x)
        elif modal == 'fmri':
            x = self.conv1[modal](x)
            # [B, 1, 8192] -> [B, 1024, 8]
            x = x.reshape(x.size(0), self.clip.visual.conv1.out_channels, -1)

        image_feats = self.clip_encode_image(x, modal=modal)
        # take mean on time dimension
        # all inputs are reduced to [B, L, D]
        bsz = int(bsz / T)
        image_feats = image_feats.reshape(
            bsz, T, *image_feats.shape[1:]).mean(dim=1)

        image_feats = self.clip_proj1[modal](image_feats)
        image_feats = torch.cat(
            [self.resample_tokens[modal].repeat(bsz, 1, 1), image_feats], dim=1)

        # routing modalities
        # [B, L, D] -> [B, L, N]
        routing_weights = self.routers[modal](image_feats).sigmoid()
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

        image_feats_experts = []
        for expert_id in range(self.num_experts):
            image_feats_expert = image_feats
            for layer in self.resample_layers[str(expert_id)]:
                image_feats_expert = layer(image_feats_expert, 0, None, None)

            image_feats_expert = image_feats_expert[:, :self.resample_tokens[modal].size(1)]
            routing_weight = routing_weights[:, :self.resample_tokens[modal].size(
                1), expert_id]
            # [B, L, D] * [B, L, 1]
            image_feats_expert = image_feats_expert * routing_weight[:, :, None]

            image_feats_experts.append(image_feats_expert)

        image_feats = sum(image_feats_experts)
        image_feats = self.clip_proj2[modal](image_feats)

        return image_feats

    def forward(self, examples, image=None, modal='image'):
        self._destroy_kv_cache()  # training always disables kv cache
        modal = modal[0]
        _bsz, seqlen = examples.shape
        h = self.tok_embeddings(examples)
        self.freqs_cis = self.freqs_cis.to(h.device)

        start_pos = 0
        prefix_len = 0
        if image is not None:
            h_bos, h_caption = h[:, :1], h[:, 1:]
            image_tokens = self.encode_image(image, modal)
            h = torch.cat((h_bos, self.start_tag[modal].expand(
                _bsz, -1, -1), image_tokens, self.end_tag[modal].expand(_bsz, -1, -1), h_caption), dim=1)
            # bos + image token + start_tag[modal], end_tag[modal] is used for caption generation
            prefix_len = image_tokens.shape[1] + 1 + 1
            seqlen = h.shape[1]

        freqs_cis = self.freqs_cis[start_pos:start_pos + seqlen]
        mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=h.device)
        mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)
        output = self.output(h[:, prefix_len:, :])
        return output

    @torch.inference_mode()
    def forward_inference(self, tokens: torch.Tensor, start_pos: int, image=None, modal='image'):
        modal = modal[0] if isinstance(modal, list) else modal
        _bsz, seqlen = tokens.shape
        if start_pos == 0:
            # kv cache will not re-allocate if size is unchanged
            self._allocate_kv_cache(_bsz)
        h = self.tok_embeddings(tokens)
        self.freqs_cis = self.freqs_cis.to(h.device)

        if image is not None:
            h_bos, h_caption = h[:, :1], h[:, 1:]
            image_tokens = self.encode_image(image, modal)
            self.cache_image_words = image_tokens.shape[1]
            h = torch.cat((h_bos, self.start_tag[modal].repeat(_bsz, 1, 1), image_tokens, self.end_tag[modal].repeat(_bsz, 1, 1), h_caption), dim=1)
            seqlen = h.shape[1]
            freqs_cis = self.freqs_cis[0: seqlen]
        else:
            if start_pos == 0:
                self.cache_image_words = 0
                freqs_cis = self.freqs_cis[0: seqlen]
            else:
                # if image was not None when start_pos=0,
                # the offset should be added to start_pos within later forward_inference calls
                start_pos = start_pos + self.cache_image_words
                freqs_cis = self.freqs_cis[start_pos: start_pos + seqlen]

        # freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        mask = None
        if seqlen > 1:
            mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
            mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)
        output = self.output(h[:, -1, :])  # only compute last logits
        return output.float()

    def _allocate_kv_cache(self, max_batch_size: int) -> None:
        for layer in self.layers:
            layer.attention.allocate_kv_cache(
                max_batch_size, self.params.max_seq_len)

    def _destroy_kv_cache(self) -> None:
        for layer in self.layers:
            layer.attention.destroy_kv_cache()
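The `forward_inference` method above returns logits only for the last position and relies on the per-layer KV cache plus `cache_image_words` to offset `start_pos` once multimodal tokens have been spliced in after BOS. Below is a minimal greedy-decoding sketch of how it is meant to be driven; it is illustrative only, and assumes `model` (an initialized `Transformer`), the tokenized prompt, and the preprocessed multimodal input are prepared elsewhere (e.g. by the demo wrappers).

# Illustrative sketch (not part of the upload): greedy decoding with forward_inference.
import torch

@torch.inference_mode()
def greedy_generate(model, prompt_tokens: torch.Tensor, image, modal: str, max_new_tokens: int = 64):
    # prompt_tokens: (1, L) long tensor of prompt token ids on the model's device
    tokens = prompt_tokens
    prev_pos = 0
    for cur_pos in range(prompt_tokens.shape[1], prompt_tokens.shape[1] + max_new_tokens):
        # first call feeds the whole prompt plus the multimodal input so the
        # visual/audio/etc. tokens are inserted after BOS; later calls feed one token at a time
        logits = model.forward_inference(
            tokens[:, prev_pos:cur_pos], prev_pos,
            image=image if prev_pos == 0 else None, modal=modal)
        next_token = logits.argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=1)
        prev_pos = cur_pos
    return tokens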
model/__init__.py
ADDED
File without changes
model/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (154 Bytes).
model/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (152 Bytes).
model/__pycache__/components.cpython-39.pyc
ADDED
Binary file (2.15 kB).
model/__pycache__/meta.cpython-310.pyc
ADDED
Binary file (5.52 kB).
model/__pycache__/meta.cpython-39.pyc
ADDED
Binary file (5.52 kB).