csuhan committed on
Commit
8b54513
1 Parent(s): 8cc8ce8

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitignore +10 -0
  2. README.md +90 -5
  3. app.py +263 -268
  4. config/llama2/7B.json +1 -0
  5. config/llama2/tokenizer.model +3 -0
  6. data/__pycache__/conversation_lib.cpython-310.pyc +0 -0
  7. data/__pycache__/conversation_lib.cpython-39.pyc +0 -0
  8. data/__pycache__/fintune_dataset.cpython-310.pyc +0 -0
  9. data/__pycache__/fintune_dataset.cpython-39.pyc +0 -0
  10. data/__pycache__/imu_utils.cpython-310.pyc +0 -0
  11. data/__pycache__/imu_utils.cpython-39.pyc +0 -0
  12. data/__pycache__/video_utils.cpython-310.pyc +0 -0
  13. data/__pycache__/video_utils.cpython-39.pyc +0 -0
  14. data/conversation_lib.py +369 -0
  15. data/fintune_dataset.py +449 -0
  16. data/imu_utils.py +257 -0
  17. data/video_utils.py +204 -0
  18. demos/multi_turn_mm.py +300 -0
  19. lib/__pycache__/point_utils.cpython-310.pyc +0 -0
  20. lib/point_utils.py +191 -0
  21. lib/pointnet2/pointnet2_modules.py +160 -0
  22. lib/pointnet2/pointnet2_utils.py +290 -0
  23. lib/pointnet2/pytorch_utils.py +236 -0
  24. lib/pointnet2/setup.py +23 -0
  25. lib/pointnet2/src/ball_query.cpp +24 -0
  26. lib/pointnet2/src/ball_query_gpu.cu +67 -0
  27. lib/pointnet2/src/ball_query_gpu.h +15 -0
  28. lib/pointnet2/src/cuda_utils.h +15 -0
  29. lib/pointnet2/src/group_points.cpp +34 -0
  30. lib/pointnet2/src/group_points_gpu.cu +86 -0
  31. lib/pointnet2/src/group_points_gpu.h +22 -0
  32. lib/pointnet2/src/interpolate.cpp +53 -0
  33. lib/pointnet2/src/interpolate_gpu.cu +161 -0
  34. lib/pointnet2/src/interpolate_gpu.h +30 -0
  35. lib/pointnet2/src/pointnet2_api.cpp +24 -0
  36. lib/pointnet2/src/sampling.cpp +45 -0
  37. lib/pointnet2/src/sampling_gpu.cu +253 -0
  38. lib/pointnet2/src/sampling_gpu.h +29 -0
  39. model/LLM/__init__.py +1 -0
  40. model/LLM/__pycache__/__init__.cpython-310.pyc +0 -0
  41. model/LLM/__pycache__/__init__.cpython-39.pyc +0 -0
  42. model/LLM/__pycache__/onellm.cpython-310.pyc +0 -0
  43. model/LLM/__pycache__/onellm.cpython-39.pyc +0 -0
  44. model/LLM/onellm.py +495 -0
  45. model/__init__.py +0 -0
  46. model/__pycache__/__init__.cpython-310.pyc +0 -0
  47. model/__pycache__/__init__.cpython-39.pyc +0 -0
  48. model/__pycache__/components.cpython-39.pyc +0 -0
  49. model/__pycache__/meta.cpython-310.pyc +0 -0
  50. model/__pycache__/meta.cpython-39.pyc +0 -0
.gitignore ADDED
@@ -0,0 +1,10 @@
+ __pycache__
+ *.pyc
+ *.egg-info
+ dist
+
+ output
+ output_dir
+ *.pth
+ *.log
+ weights
README.md CHANGED
@@ -1,15 +1,100 @@
  ---
- title: LLaMA Adapter V2
  emoji: 🚀
  colorFrom: red
  colorTo: indigo
  sdk: gradio
- sdk_version: 3.23.0
  app_file: app.py
  pinned: false
  ---

- ### LLaMA-Adapter
- The official demo for LLaMA-Adapter V2.
- Please refer to our [arXiv paper](https://arxiv.org/abs/2303.16199) and [github](https://github.com/ZrrSkywalker/LLaMA-Adapter) for more details.

  ---
+ title: OneLLM
  emoji: 🚀
  colorFrom: red
  colorTo: indigo
  sdk: gradio
+ sdk_version: 4.7.1
  app_file: app.py
  pinned: false
  ---

+ # OneLLM: One Framework to Align All Modalities with Language

+ [[Project Page](https://onellm.csuhan.com)] [[Paper](#)] [[Web Demo](https://huggingface.co/spaces/csuhan/OneLLM)]
+
+ Authors: [Jiaming Han](), [Kaixiong Gong](), [Yiyuan Zhang](), [Jiaqi Wang](), [Kaipeng Zhang](), [Dahua Lin](), [Yu Qiao](), [Peng Gao](), [Xiangyu Yue]().
+
+ ## News
+
+ - **2023.12.01** Release model weights and inference code.
+
+ ## Contents
+
+ - [Install](#install)
+ - [Models](#models)
+ - [Demo](#demo)
+
+ <!-- - [Evaluation](#evaluation) -->
+
+ <!-- - [Training](#training) -->
+
+ ### TODO
+
+ - [ ] Data
+ - [ ] Evaluation
+ - [ ] Training
+
+ ### Install
+
+ 1. Clone the repo into a local folder.
+
+ ```bash
+ git clone https://github.com/csuhan/OneLLM
+
+ cd OneLLM
+ ```
+
+ 2. Install packages.
+
+ ```bash
+ conda create -n onellm python=3.9 -y
+ conda activate onellm
+
+ pip install -r requirements.txt
+
+ # install pointnet
+ cd lib/pointnet2
+ python setup.py install
+ ```
+
+ 3. Install Apex (optional).
+
+ ```bash
+ git clone https://github.com/NVIDIA/apex
+ cd apex
+ pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
+ ```
+
+ ### Models
+
+ We provide a preview model at [csuhan/OneLLM-7B](https://huggingface.co/csuhan/OneLLM-7B).
+
+ ### Demo
+
+ **Huggingface Demo:** [csuhan/OneLLM](https://huggingface.co/spaces/csuhan/OneLLM).
+
+ **Local Demo:** Assuming you have downloaded the weights to ${WEIGHTS_DIR}, run the following command to start a Gradio demo locally (see the download sketch below):
+
+ ```bash
+ python demos/multi_turn_mm.py --gpu_ids 0 --tokenizer_path config/llama2/tokenizer.model --llama_config config/llama2/7B.json --pretrained_path ${WEIGHTS_DIR}/consolidated.00-of-01.pth
+ ```
+
+ <!-- ### Evaluation -->
+
+ <!-- ### Training -->
+
+ ## Citation
+
+ ```
+ @article{han2023onellm,
+   title={OneLLM: One Framework to Align All Modalities with Language},
+   author={Han, Jiaming and Gong, Kaixiong and Zhang, Yiyuan and Wang, Jiaqi and Zhang, Kaipeng and Lin, Dahua and Qiao, Yu and Gao, Peng and Yue, Xiangyu},
+   journal={arXiv preprint arXiv:xxxx},
+   year={2023}
+ }
+ ```
+
+ ## Acknowledgement
+
+ [LLaMA](https://github.com/facebookresearch/llama), [LLaMA-Adapter](https://github.com/OpenGVLab/LLaMA-Adapter), [LLaMA2-Accessory](https://github.com/Alpha-VLLM/LLaMA2-Accessory), [Meta-Transformer](https://github.com/invictus717/MetaTransformer), [ChatBridge](https://github.com/joez17/ChatBridge)
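
The Models and Local Demo sections above assume the checkpoint is already available under ${WEIGHTS_DIR}. As a minimal sketch (reusing the same `hf_hub_download` call that the new `app.py` makes), the preview weights could be fetched like this:

```python
# Fetch the OneLLM-7B preview checkpoint from the Hugging Face Hub; the
# returned local path can then be passed to demos/multi_turn_mm.py via
# --pretrained_path.
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(repo_id="csuhan/OneLLM-7B",
                            filename="consolidated.00-of-01.pth")
print(ckpt_path)
```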
app.py CHANGED
@@ -1,277 +1,272 @@
1
- import json
2
- import os
3
- import glob
4
  import sys
5
- import time
6
- from pathlib import Path
7
- from typing import Tuple
 
 
 
8
 
9
- from huggingface_hub import hf_hub_download
10
- from PIL import Image
11
- import gradio as gr
12
  import torch
13
- from fairscale.nn.model_parallel.initialize import initialize_model_parallel
14
-
15
- from llama import LLaMA, ModelArgs, Tokenizer, Transformer, VisionModel
16
-
17
- os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
18
-
19
- PROMPT_DICT = {
20
- "prompt_input": (
21
- "Below is an instruction that describes a task, paired with an input that provides further context. "
22
- "Write a response that appropriately completes the request.\n\n"
23
- "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
24
- ),
25
- "prompt_no_input": (
26
- "Below is an instruction that describes a task. "
27
- "Write a response that appropriately completes the request.\n\n"
28
- "### Instruction:\n{instruction}\n\n### Response:"
29
- ),
30
- }
31
-
32
-
33
- def setup_model_parallel() -> Tuple[int, int]:
34
- os.environ['RANK'] = '0'
35
- os.environ['WORLD_SIZE'] = '1'
36
- os.environ['MP'] = '1'
37
- os.environ['MASTER_ADDR'] = '127.0.0.1'
38
- os.environ['MASTER_PORT'] = '2223'
39
- local_rank = int(os.environ.get("LOCAL_RANK", -1))
40
- world_size = int(os.environ.get("WORLD_SIZE", -1))
41
-
42
- torch.distributed.init_process_group("nccl")
43
- initialize_model_parallel(world_size)
44
- torch.cuda.set_device(local_rank)
45
-
46
- # seed must be the same in all processes
47
- torch.manual_seed(1)
48
- return local_rank, world_size
49
-
50
-
51
- def load(
52
- ckpt0_path: str,
53
- ckpt1_path: str,
54
- param_path: str,
55
- tokenizer_path: str,
56
- instruct_adapter_path: str,
57
- caption_adapter_path: str,
58
- local_rank: int,
59
- world_size: int,
60
- max_seq_len: int,
61
- max_batch_size: int,
62
- ) -> LLaMA:
63
- start_time = time.time()
64
- print("Loading")
65
- instruct_adapter_checkpoint = torch.load(
66
- instruct_adapter_path, map_location="cpu")
67
- caption_adapter_checkpoint = torch.load(
68
- caption_adapter_path, map_location="cpu")
69
- with open(param_path, "r") as f:
70
- params = json.loads(f.read())
71
-
72
- model_args: ModelArgs = ModelArgs(
73
- max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
74
- )
75
- model_args.adapter_layer = int(
76
- instruct_adapter_checkpoint['adapter_query.weight'].shape[0] / model_args.adapter_len)
77
- model_args.cap_adapter_layer = int(
78
- caption_adapter_checkpoint['cap_adapter_query.weight'].shape[0] / model_args.cap_adapter_len)
79
-
80
- tokenizer = Tokenizer(model_path=tokenizer_path)
81
- model_args.vocab_size = tokenizer.n_words
82
- torch.set_default_tensor_type(torch.cuda.HalfTensor)
83
- model = Transformer(model_args)
84
-
85
- # To reduce memory usuage
86
- ckpt0 = torch.load(ckpt0_path, map_location='cuda')
87
- model.load_state_dict(ckpt0, strict=False)
88
- del ckpt0
89
- torch.cuda.empty_cache()
90
-
91
- ckpt1 = torch.load(ckpt1_path, map_location='cuda')
92
- model.load_state_dict(ckpt1, strict=False)
93
- del ckpt1
94
- torch.cuda.empty_cache()
95
-
96
- vision_model = VisionModel(model_args)
97
-
98
- torch.set_default_tensor_type(torch.FloatTensor)
99
- model.load_state_dict(instruct_adapter_checkpoint, strict=False)
100
- model.load_state_dict(caption_adapter_checkpoint, strict=False)
101
- vision_model.load_state_dict(caption_adapter_checkpoint, strict=False)
102
-
103
- generator = LLaMA(model, tokenizer, vision_model)
104
- print(f"Loaded in {time.time() - start_time:.2f} seconds")
105
- return generator
106
-
107
-
108
- def instruct_generate(
109
- instruct: str,
110
- input: str = 'none',
111
- max_gen_len=512,
112
- temperature: float = 0.1,
113
- top_p: float = 0.75,
114
- ):
115
- if input == 'none':
116
- prompt = PROMPT_DICT['prompt_no_input'].format_map(
117
- {'instruction': instruct, 'input': ''})
118
- else:
119
- prompt = PROMPT_DICT['prompt_input'].format_map(
120
- {'instruction': instruct, 'input': input})
121
-
122
- results = generator.generate(
123
- [prompt], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p
124
- )
125
- result = results[0].strip()
126
- print(result)
127
- return result
128
-
129
-
130
- def caption_generate(
131
- img: str,
132
- max_gen_len=512,
133
- temperature: float = 0.1,
134
- top_p: float = 0.75,
135
- ):
136
- imgs = [Image.open(img).convert('RGB')]
137
- prompts = ["Generate caption of this image :",] * len(imgs)
138
-
139
- results = generator.generate(
140
- prompts, imgs=imgs, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p
141
- )
142
- result = results[0].strip()
143
- print(result)
144
- return result
145
-
146
-
147
- def download_llama_adapter(instruct_adapter_path, caption_adapter_path):
148
- if not os.path.exists(instruct_adapter_path):
149
- os.system(
150
- f"wget -q -O {instruct_adapter_path} https://github.com/ZrrSkywalker/LLaMA-Adapter/releases/download/v.1.0.0/llama_adapter_len10_layer30_release.pth")
151
-
152
- if not os.path.exists(caption_adapter_path):
153
- os.system(
154
- f"wget -q -O {caption_adapter_path} https://github.com/ZrrSkywalker/LLaMA-Adapter/releases/download/v.1.0.0/llama_adapter_len10_layer30_caption_vit_l.pth")
155
-
156
-
157
- # ckpt_path = "/data1/llma/7B/consolidated.00.pth"
158
- # param_path = "/data1/llma/7B/params.json"
159
- # tokenizer_path = "/data1/llma/tokenizer.model"
160
- ckpt0_path = hf_hub_download(
161
- repo_id="csuhan/llama_storage", filename="consolidated.00_part0.pth")
162
- ckpt1_path = hf_hub_download(
163
- repo_id="csuhan/llama_storage", filename="consolidated.00_part1.pth")
164
- param_path = hf_hub_download(
165
- repo_id="nyanko7/LLaMA-7B", filename="params.json")
166
- tokenizer_path = hf_hub_download(
167
- repo_id="nyanko7/LLaMA-7B", filename="tokenizer.model")
168
- instruct_adapter_path = "llama_adapter_len10_layer30_release.pth"
169
- caption_adapter_path = "llama_adapter_len10_layer30_caption_vit_l.pth"
170
- max_seq_len = 512
171
- max_batch_size = 1
172
-
173
- # download models
174
- # download_llama_adapter(instruct_adapter_path, caption_adapter_path)
175
-
176
- local_rank, world_size = setup_model_parallel()
177
- if local_rank > 0:
178
- sys.stdout = open(os.devnull, "w")
179
-
180
- generator = load(
181
- ckpt0_path, ckpt1_path, param_path, tokenizer_path, instruct_adapter_path, caption_adapter_path, local_rank, world_size, max_seq_len, max_batch_size
182
- )
183
-
184
-
185
- def create_instruct_demo():
186
- with gr.Blocks() as instruct_demo:
187
- with gr.Row():
188
- with gr.Column():
189
- instruction = gr.Textbox(lines=2, label="Instruction")
190
- input = gr.Textbox(
191
- lines=2, label="Context input", placeholder='none')
192
- max_len = gr.Slider(minimum=1, maximum=512,
193
- value=128, label="Max length")
194
- with gr.Accordion(label='Advanced options', open=False):
195
- temp = gr.Slider(minimum=0, maximum=1,
196
- value=0.1, label="Temperature")
197
- top_p = gr.Slider(minimum=0, maximum=1,
198
- value=0.75, label="Top p")
199
-
200
- run_botton = gr.Button("Run")
201
-
202
- with gr.Column():
203
- outputs = gr.Textbox(lines=10, label="Output")
204
-
205
- inputs = [instruction, input, max_len, temp, top_p]
206
-
207
- examples = [
208
- "Tell me about alpacas.",
209
- "Write a Python program that prints the first 10 Fibonacci numbers.",
210
- "Write a conversation between the sun and pluto.",
211
- "Write a theory to explain why cat never existed",
212
- ]
213
- examples = [
214
- [x, "none", 128, 0.1, 0.75]
215
- for x in examples]
216
-
217
- gr.Examples(
218
- examples=examples,
219
- inputs=inputs,
220
- outputs=outputs,
221
- fn=instruct_generate,
222
- cache_examples=os.getenv('SYSTEM') == 'spaces'
223
- )
224
- run_botton.click(fn=instruct_generate, inputs=inputs, outputs=outputs)
225
- return instruct_demo
226
 
 
227
 
228
- def create_caption_demo():
229
- with gr.Blocks() as instruct_demo:
230
- with gr.Row():
231
- with gr.Column():
232
- img = gr.Image(label='Input', type='filepath')
233
- max_len = gr.Slider(minimum=1, maximum=512,
234
- value=64, label="Max length")
235
- with gr.Accordion(label='Advanced options', open=False):
236
- temp = gr.Slider(minimum=0, maximum=1,
237
- value=0.1, label="Temperature")
238
- top_p = gr.Slider(minimum=0, maximum=1,
239
- value=0.75, label="Top p")
240
-
241
- run_botton = gr.Button("Run")
242
-
243
- with gr.Column():
244
- outputs = gr.Textbox(lines=10, label="Output")
245
-
246
- inputs = [img, max_len, temp, top_p]
247
-
248
- examples = glob.glob("caption_demo/*.jpg")
249
- examples = [
250
- [x, 64, 0.1, 0.75]
251
- for x in examples]
252
-
253
- gr.Examples(
254
- examples=examples,
255
- inputs=inputs,
256
- outputs=outputs,
257
- fn=caption_generate,
258
- cache_examples=os.getenv('SYSTEM') == 'spaces'
259
- )
260
- run_botton.click(fn=caption_generate, inputs=inputs, outputs=outputs)
261
- return instruct_demo
262
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
 
264
- description = """
265
- # LLaMA-Adapter🚀
266
- The official demo for **LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention**.
267
- Please refer to our [arXiv paper](https://arxiv.org/abs/2303.16199) and [github](https://github.com/ZrrSkywalker/LLaMA-Adapter) for more details.
268
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
 
270
- with gr.Blocks(css='style.css') as demo:
271
- gr.Markdown(description)
272
- with gr.TabItem("Instruction-Following"):
273
- create_instruct_demo()
274
- with gr.TabItem("Image Captioning"):
275
- create_caption_demo()
276
 
277
- demo.queue(api_open=True, concurrency_count=1).launch()
 
 
 
 
1
  import sys
2
+ import os
3
+
4
+ import argparse
5
+ import multiprocessing as mp
6
+ import numpy as np
7
+ from typing import List, Optional
8
 
 
 
 
9
  import torch
10
+ import torch.distributed as dist
11
 
12
+ from fairscale.nn.model_parallel import initialize as fs_init
13
 
14
+ import gradio as gr
15
+ from util.misc import setup_for_distributed
16
+ from util.misc import default_tensor_type
17
+ from model.meta import MetaModel
18
+ from data.conversation_lib import conv_templates, SeparatorStyle
19
+ from PIL import Image
20
+ import torchvision.transforms as transforms
21
+ from data.fintune_dataset import make_audio_features
22
+ from data import video_utils
23
+ from dataclasses import dataclass
24
+ from huggingface_hub import hf_hub_download
25
 
26
+ T_random_resized_crop = transforms.Compose([
27
+ transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333), interpolation=3,
28
+ antialias=None), # 3 is bicubic
29
+ transforms.ToTensor(),
30
+ transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
31
+
32
+
33
+ def load_audio(audio_path):
34
+ fbank = make_audio_features(audio_path, mel_bins=128)
35
+ fbank = fbank.transpose(0, 1)[None] #[1, 128, 1024]
36
+ return fbank
37
+
38
+ def load_video(video_path):
39
+ video_feats = video_utils.load_and_transform_video_data(video_path, video_path, clip_duration=1, clips_per_video=5)
40
+ return video_feats[:, :, 0]
41
+
42
+
43
+ def model_worker(
44
+ rank: int, args: argparse.Namespace, barrier: mp.Barrier,
45
+ request_queue: mp.Queue, response_queue: Optional[mp.Queue] = None,
46
+ ) -> None:
47
+ """
48
+ The worker function that manipulates the GPU to run the inference.
49
+ Exactly n_gpu workers are started, with each one operating on a separate GPU.
50
+
51
+ Args:
52
+ rank (int): Distributed rank of the worker.
53
+ args (argparse.Namespace): All command line arguments.
54
+ barrier (multiprocessing.Barrier): A barrier used to delay the start
55
+ of the Web UI until the model has started.
56
+ """
57
+
58
+ world_size = len(args.gpu_ids)
59
+ gpu_id = args.gpu_ids[rank]
60
+ dist.init_process_group(
61
+ backend="nccl", rank=rank, world_size=world_size,
62
+ init_method=f"tcp://{args.master_addr}:{args.master_port}",
63
+ )
64
+ print(f"| distributed init on worker {rank}/{world_size}. "
65
+ f"using gpu: {gpu_id}")
66
+ fs_init.initialize_model_parallel(world_size)
67
+ torch.cuda.set_device(gpu_id)
68
 
69
+ torch.manual_seed(1)
70
+ np.random.seed(1)
71
+
72
+ # set the print behavior.
73
+ setup_for_distributed(rank == 0)
74
+
75
+ target_dtype = {
76
+ "bf16": torch.bfloat16,
77
+ "fp16": torch.float16
78
+ }[args.dtype]
79
+ with default_tensor_type(dtype=target_dtype, device="cuda"):
80
+ model = MetaModel(args.llama_type, args.llama_config, tokenizer_path=args.tokenizer_path)
81
+ print("Loading pretrained weights ...")
82
+ checkpoint = torch.load(args.pretrained_path, map_location='cpu')
83
+ msg = model.load_state_dict(checkpoint, strict=False)
84
+ print("load result:\n", msg)
85
+ model.cuda()
86
+ model.eval()
87
+ print(f"Model = {str(model)}")
88
+
89
+ barrier.wait()
90
+
91
+ while True:
92
+ img_path, audio_path, video_path, chatbot, max_gen_len, temperature, top_p, modality = request_queue.get()
93
+ if 'image' in modality and img_path is not None:
94
+ image = Image.open(img_path).convert('RGB')
95
+ inputs = T_random_resized_crop(image)
96
+ elif 'video' in modality and video_path is not None:
97
+ inputs = load_video(video_path)
98
+ elif 'audio' in modality and audio_path is not None:
99
+ inputs = load_audio(audio_path)
100
+ else:
101
+ inputs = None
102
+
103
+ if inputs is not None:
104
+ inputs = inputs[None].cuda().to(target_dtype)
105
+
106
+ conv = conv_templates["v1"].copy()
107
+ for user, bot in chatbot:
108
+ conv.append_message(conv.roles[0], user)
109
+ conv.append_message(conv.roles[1], bot)
110
+
111
+ with torch.cuda.amp.autocast(dtype=target_dtype):
112
+ print(conv.get_prompt())
113
+ for stream_response in model.stream_generate(
114
+ conv.get_prompt(), inputs,
115
+ max_gen_len=max_gen_len, temperature=temperature, top_p=top_p,
116
+ modal = modality
117
+ ):
118
+ conv_sep = (
119
+ conv.sep
120
+ if conv.sep_style == SeparatorStyle.SINGLE
121
+ else conv.sep2
122
+ )
123
+ end_pos = stream_response["text"].find(conv_sep)
124
+ if end_pos != -1:
125
+ stream_response["text"] = (
126
+ stream_response['text'][:end_pos].rstrip() + "\n"
127
+ )
128
+ stream_response["end_of_content"] = True
129
+
130
+ # keep a few characters if not end_of_content to avoid sending
131
+ # part of conv_sep before all of it is generated.
132
+ if not stream_response["end_of_content"]:
133
+ if len(stream_response["text"]) < len(conv_sep):
134
+ continue
135
+ stream_response["text"] = (
136
+ stream_response["text"][:-len(conv_sep)]
137
+ )
138
+
139
+ if response_queue is not None:
140
+ response_queue.put(stream_response)
141
+
142
+ if stream_response["end_of_content"]:
143
+ break
144
+
145
+
146
+ def gradio_worker(
147
+ request_queues: List[mp.Queue], response_queue: mp.Queue,
148
+ args: argparse.Namespace, barrier: mp.Barrier,
149
+ ) -> None:
150
+ """
151
+ The gradio worker is responsible for displaying the WebUI and relaying the
152
+ requests to model workers. It should be launched only once.
153
+
154
+ Args:
155
+ request_queues (List[mp.Queue]): A list of request queues (one for
156
+ each model worker).
157
+ args (argparse.Namespace): All command line arguments.
158
+ barrier (multiprocessing.Barrier): A barrier used to delay the start
159
+ of the Web UI until the model has started.
160
+ """
161
+
162
+ def show_user_input(msg, chatbot):
163
+ return "", chatbot + [[msg, None]]
164
+
165
+ def stream_model_output(img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality):
166
+ for queue in request_queues:
167
+ queue.put((img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality))
168
+ while True:
169
+ content_piece = response_queue.get()
170
+ chatbot[-1][1] = content_piece["text"]
171
+ yield chatbot
172
+ if content_piece["end_of_content"]:
173
+ break
174
+
175
+ def undo(chatbot):
176
+ if len(chatbot) > 0:
177
+ chatbot = chatbot[:-1]
178
+ return chatbot
179
+
180
+ def clear():
181
+ chatbot = []
182
+ msg = ""
183
+ return chatbot, msg
184
+
185
+ CSS ="""
186
+ .contain { display: flex; flex-direction: column; }
187
+ #component-0 { height: 100%; }
188
+ #chatbot { flex-grow: 1; overflow: auto;}
189
+ """
190
+ with gr.Blocks(css=CSS) as demo:
191
+ gr.Markdown("## OneLLM: One Framework to Align All Modalities with Language")
192
+ with gr.Row(equal_height=True):
193
+ with gr.Column(scale=1):
194
+ img_path = gr.Image(label='Image Input', type='filepath')
195
+ video_path = gr.Video(label='Video Input')
196
+ audio_path = gr.Audio(label='Audio Input', type='filepath', sources=['upload'])
197
+ modality = gr.Radio(choices=['image', 'audio', 'video'], value='image', interactive=True, label='Input Modalities')
198
+
199
+ with gr.Column(scale=2):
200
+ chatbot = gr.Chatbot(elem_id="chatbot")
201
+ msg = gr.Textbox()
202
 
203
+ with gr.Row():
204
+ submit_button = gr.Button("Submit", variant="primary")
205
+ undo_button = gr.Button("Undo")
206
+ clear_button = gr.ClearButton([chatbot, msg, img_path, audio_path, video_path, modality])
207
+ with gr.Row():
208
+ max_gen_len = gr.Slider(
209
+ minimum=1, maximum=args.model_max_seq_len // 2,
210
+ value=args.model_max_seq_len // 2, interactive=True,
211
+ label="Single-turn max response length",
212
+ )
213
+ gen_t = gr.Slider(
214
+ minimum=0, maximum=1, value=0.1, interactive=True,
215
+ label="Temperature",
216
+ )
217
+ top_p = gr.Slider(
218
+ minimum=0, maximum=1, value=0.75, interactive=True,
219
+ label="Top-p",
220
+ )
221
+ msg.submit(
222
+ show_user_input, [msg, chatbot], [msg, chatbot],
223
+ ).then(
224
+ stream_model_output, [img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot,
225
+ )
226
+ submit_button.click(
227
+ show_user_input, [msg, chatbot], [msg, chatbot],
228
+ ).then(
229
+ stream_model_output, [img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot,
230
+ )
231
+ undo_button.click(undo, chatbot, chatbot)
232
+ # img_path.change(clear, [], [chatbot, msg])
233
+ barrier.wait()
234
+ demo.queue(api_open=True).launch(share=True, max_threads=1)
235
+
236
+
237
+ @dataclass
238
+ class DemoConfig:
239
+ gpu_ids = [0]
240
+ tokenizer_path = "config/llama2/tokenizer.model"
241
+ llama_type = "onellm"
242
+ llama_config = "config/llama2/7B.json"
243
+ model_max_seq_len = 2048
244
+ # pretrained_path = "weights/7B_2048/consolidated.00-of-01.pth"
245
+ pretrained_path = hf_hub_download(repo_id="csuhan/OneLLM-7B", filename="consolidated.00-of-01.pth")
246
+ master_port = 23861
247
+ master_addr = "127.0.0.1"
248
+ dtype = "fp16"
249
+
250
+ if __name__ == "__main__":
251
+ args = DemoConfig()
252
+ # using the default "fork" method messes up some imported libs (e.g.,
253
+ # pandas)
254
+ mp.set_start_method("spawn")
255
+
256
+ # setup the queues and start the model workers
257
+ request_queues = []
258
+ response_queue = mp.Queue()
259
+ worker_processes = []
260
+ barrier = mp.Barrier(len(args.gpu_ids) + 1)
261
+ for rank, gpu_id in enumerate(args.gpu_ids):
262
+ request_queue = mp.Queue()
263
+ rank_response_queue = response_queue if rank == 0 else None
264
+ process = mp.Process(
265
+ target=model_worker,
266
+ args=(rank, args, barrier, request_queue, rank_response_queue),
267
+ )
268
+ process.start()
269
+ worker_processes.append(process)
270
+ request_queues.append(request_queue)
271
 
272
+ gradio_worker(request_queues, response_queue, args, barrier)
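
The docstrings above describe the process layout: one request queue per GPU worker, a single response queue owned by rank 0, and a barrier that holds the Gradio UI back until every worker has loaded the model. A self-contained toy sketch of that queue/barrier pattern (toy payloads stand in for the real model and media inputs):

```python
# Minimal sketch of the worker layout used in app.py: one request queue per
# worker, one response queue streamed back from rank 0, and a barrier with
# (n_workers + 1) parties so the UI process starts only after the workers.
import multiprocessing as mp

def toy_worker(rank, barrier, request_queue, response_queue):
    barrier.wait()                      # signal readiness, like model_worker
    prompt = request_queue.get()        # every worker receives the broadcast request
    if response_queue is not None:      # only rank 0 sends results back
        response_queue.put({"text": f"echo: {prompt}", "end_of_content": True})

if __name__ == "__main__":
    mp.set_start_method("spawn")
    n_workers = 2
    barrier = mp.Barrier(n_workers + 1)          # +1 for the UI process
    response_queue = mp.Queue()
    request_queues = []
    for rank in range(n_workers):
        q = mp.Queue()
        mp.Process(target=toy_worker,
                   args=(rank, barrier, q,
                         response_queue if rank == 0 else None)).start()
        request_queues.append(q)
    barrier.wait()                               # UI starts after workers are up
    for q in request_queues:                     # broadcast one request
        q.put("hello")
    print(response_queue.get())                  # {'text': 'echo: hello', ...}
```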
config/llama2/7B.json ADDED
@@ -0,0 +1 @@
+ {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-05, "vocab_size": -1}
config/llama2/tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
+ size 499723
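
The tokenizer model is stored as a Git LFS pointer. Elsewhere in this commit it is consumed through a `Tokenizer` wrapper (`model.tokenizer`, imported by `data/fintune_dataset.py`); a hedged usage sketch, assuming that wrapper keeps the `encode(text, bos=..., eos=...)` interface seen at those call sites:

```python
# Sketch only: the Tokenizer implementation is not shown in this 50-file
# view; the encode(bos=..., eos=...) signature is taken from how
# data/fintune_dataset.py calls it.
from model.tokenizer import Tokenizer

tok = Tokenizer(model_path="config/llama2/tokenizer.model")
ids = tok.encode("Hello from OneLLM", bos=True, eos=False)
print(len(ids), ids[:5])
```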
data/__pycache__/conversation_lib.cpython-310.pyc ADDED
Binary file (9.14 kB).

data/__pycache__/conversation_lib.cpython-39.pyc ADDED
Binary file (9.15 kB).

data/__pycache__/fintune_dataset.cpython-310.pyc ADDED
Binary file (14.2 kB).

data/__pycache__/fintune_dataset.cpython-39.pyc ADDED
Binary file (14.2 kB).

data/__pycache__/imu_utils.cpython-310.pyc ADDED
Binary file (6.71 kB).

data/__pycache__/imu_utils.cpython-39.pyc ADDED
Binary file (6.71 kB).

data/__pycache__/video_utils.cpython-310.pyc ADDED
Binary file (6.53 kB).

data/__pycache__/video_utils.cpython-39.pyc ADDED
Binary file (6.51 kB).
data/conversation_lib.py ADDED
@@ -0,0 +1,369 @@
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple
4
+
5
+
6
+ class SeparatorStyle(Enum):
7
+ """Different separator style."""
8
+ SINGLE = auto()
9
+ TWO = auto()
10
+ MPT = auto()
11
+
12
+
13
+ @dataclasses.dataclass
14
+ class Conversation:
15
+ """A class that keeps all conversation history."""
16
+ system: str
17
+ roles: List[str]
18
+ messages: List[List[str]]
19
+ offset: int
20
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
21
+ sep: str = "###"
22
+ sep2: str = None
23
+ version: str = "Unknown"
24
+
25
+ skip_next: bool = False
26
+
27
+ def get_prompt(self):
28
+ if self.sep_style == SeparatorStyle.SINGLE:
29
+ ret = self.system + '\n\n' + self.sep
30
+ for role, message in self.messages:
31
+ if message:
32
+ if type(message) is tuple:
33
+ message, _, _ = message
34
+ ret += role + ": " + message + '\n' + self.sep
35
+ else:
36
+ ret += role + ":"
37
+ return ret
38
+ elif self.sep_style == SeparatorStyle.TWO:
39
+ seps = [self.sep, self.sep2]
40
+ ret = self.system + seps[0]
41
+ for i, (role, message) in enumerate(self.messages):
42
+ if message:
43
+ if type(message) is tuple:
44
+ message, _, _ = message
45
+ ret += role + ": " + message + seps[i % 2]
46
+ else:
47
+ ret += role + ":"
48
+ return ret
49
+ if self.sep_style == SeparatorStyle.MPT:
50
+ ret = self.system + self.sep
51
+ for role, message in self.messages:
52
+ if message:
53
+ if type(message) is tuple:
54
+ message, _, _ = message
55
+ ret += role + message + self.sep
56
+ else:
57
+ ret += role
58
+ return ret
59
+ else:
60
+ raise ValueError(f"Invalid style: {self.sep_style}")
61
+
62
+ def append_message(self, role, message):
63
+ self.messages.append([role, message])
64
+
65
+ def get_images(self, return_pil=False):
66
+ images = []
67
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
68
+ if i % 2 == 0:
69
+ if type(msg) is tuple:
70
+ import base64
71
+ from io import BytesIO
72
+ from PIL import Image
73
+ msg, image, image_process_mode = msg
74
+ if image_process_mode == "Pad":
75
+ def expand2square(pil_img, background_color=(122, 116, 104)):
76
+ width, height = pil_img.size
77
+ if width == height:
78
+ return pil_img
79
+ elif width > height:
80
+ result = Image.new(pil_img.mode, (width, width), background_color)
81
+ result.paste(pil_img, (0, (width - height) // 2))
82
+ return result
83
+ else:
84
+ result = Image.new(pil_img.mode, (height, height), background_color)
85
+ result.paste(pil_img, ((height - width) // 2, 0))
86
+ return result
87
+
88
+ image = expand2square(image)
89
+ elif image_process_mode == "Crop":
90
+ pass
91
+ elif image_process_mode == "Resize":
92
+ image = image.resize((224, 224))
93
+ else:
94
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
95
+ max_hw, min_hw = max(image.size), min(image.size)
96
+ aspect_ratio = max_hw / min_hw
97
+ max_len, min_len = 800, 400
98
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
99
+ longest_edge = int(shortest_edge * aspect_ratio)
100
+ W, H = image.size
101
+ if H > W:
102
+ H, W = longest_edge, shortest_edge
103
+ else:
104
+ H, W = shortest_edge, longest_edge
105
+ image = image.resize((W, H))
106
+ if return_pil:
107
+ images.append(image)
108
+ else:
109
+ buffered = BytesIO()
110
+ image.save(buffered, format="JPEG")
111
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
112
+ images.append(img_b64_str)
113
+ return images
114
+
115
+ def to_gradio_chatbot(self):
116
+ ret = []
117
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
118
+ if i % 2 == 0:
119
+ if type(msg) is tuple:
120
+ import base64
121
+ from io import BytesIO
122
+ msg, image, image_process_mode = msg
123
+ max_hw, min_hw = max(image.size), min(image.size)
124
+ aspect_ratio = max_hw / min_hw
125
+ max_len, min_len = 800, 400
126
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
127
+ longest_edge = int(shortest_edge * aspect_ratio)
128
+ W, H = image.size
129
+ if H > W:
130
+ H, W = longest_edge, shortest_edge
131
+ else:
132
+ H, W = shortest_edge, longest_edge
133
+ image = image.resize((W, H))
134
+ # image = image.resize((224, 224))
135
+ buffered = BytesIO()
136
+ image.save(buffered, format="JPEG")
137
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
138
+ img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
139
+ msg = msg.replace('<image>', img_str)
140
+ ret.append([msg, None])
141
+ else:
142
+ ret[-1][-1] = msg
143
+ return ret
144
+
145
+ def copy(self):
146
+ return Conversation(
147
+ system=self.system,
148
+ roles=self.roles,
149
+ messages=[[x, y] for x, y in self.messages],
150
+ offset=self.offset,
151
+ sep_style=self.sep_style,
152
+ sep=self.sep,
153
+ sep2=self.sep2)
154
+
155
+ def dict(self):
156
+ if len(self.get_images()) > 0:
157
+ return {
158
+ "system": self.system,
159
+ "roles": self.roles,
160
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
161
+ "offset": self.offset,
162
+ "sep": self.sep,
163
+ "sep2": self.sep2,
164
+ }
165
+ return {
166
+ "system": self.system,
167
+ "roles": self.roles,
168
+ "messages": self.messages,
169
+ "offset": self.offset,
170
+ "sep": self.sep,
171
+ "sep2": self.sep2,
172
+ }
173
+
174
+
175
+ conv_v1 = Conversation(
176
+ system="A chat between a curious human and an artificial intelligence assistant. "
177
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
178
+ roles=("Human", "Assistant"),
179
+ messages=(
180
+ ("Human", "Give three tips for staying healthy."),
181
+ ("Assistant",
182
+ "Sure, here are three tips for staying healthy:\n"
183
+ "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
184
+ "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
185
+ "and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or "
186
+ "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening "
187
+ "activities at least two days per week.\n"
188
+ "2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, "
189
+ "vegetables, whole grains, lean proteins, and healthy fats can help support "
190
+ "your overall health. Try to limit your intake of processed and high-sugar foods, "
191
+ "and aim to drink plenty of water throughout the day.\n"
192
+ "3. Get enough sleep: Getting enough quality sleep is essential for your physical "
193
+ "and mental health. Adults should aim for seven to nine hours of sleep per night. "
194
+ "Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
195
+ "help improve the quality of your sleep.")
196
+ ),
197
+ offset=2,
198
+ sep_style=SeparatorStyle.SINGLE,
199
+ sep="###",
200
+ )
201
+
202
+ conv_v1_2 = Conversation(
203
+ system="A chat between a curious human and an artificial intelligence assistant. "
204
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
205
+ roles=("Human", "Assistant"),
206
+ messages=(),
207
+
208
+ # (
209
+ # ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
210
+ # ("Assistant",
211
+ # "Renewable energy sources are those that can be replenished naturally in a relatively "
212
+ # "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
213
+ # "Non-renewable energy sources, on the other hand, are finite and will eventually be "
214
+ # "depleted, such as coal, oil, and natural gas. Here are some key differences between "
215
+ # "renewable and non-renewable energy sources:\n"
216
+ # "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
217
+ # "energy sources are finite and will eventually run out.\n"
218
+ # "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
219
+ # "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
220
+ # "and other negative effects.\n"
221
+ # "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
222
+ # "have lower operational costs than non-renewable sources.\n"
223
+ # "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
224
+ # "locations than non-renewable sources.\n"
225
+ # "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
226
+ # "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
227
+ # "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
228
+ # "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
229
+ # )
230
+ offset = 2,
231
+ sep_style = SeparatorStyle.SINGLE,
232
+ sep = "###",
233
+ )
234
+
235
+ conv_vicuna_v1_1 = Conversation(
236
+ system="A chat between a curious user and an artificial intelligence assistant. "
237
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
238
+ roles=("USER", "ASSISTANT"),
239
+ version="v1",
240
+ messages=(),
241
+ offset=0,
242
+ sep_style=SeparatorStyle.TWO,
243
+ sep=" ",
244
+ sep2="</s>",
245
+ )
246
+
247
+ conv_mpt = Conversation(
248
+ system="""<|im_start|>system
249
+ - You are a helpful language and vision assistant.
250
+ - You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
251
+ - You should follow the instructions carefully and explain your answers in detail.""",
252
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
253
+ version="mpt",
254
+ messages=(),
255
+ offset=0,
256
+ sep_style=SeparatorStyle.MPT,
257
+ sep="<|im_end|>",
258
+ )
259
+
260
+ conv_mpt_text = Conversation(
261
+ system="""<|im_start|>system
262
+ - You are a helpful assistant chatbot trained by MosaicML.
263
+ - You answer questions.
264
+ - You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
265
+ - You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""",
266
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
267
+ version="mpt",
268
+ messages=(),
269
+ offset=0,
270
+ sep_style=SeparatorStyle.MPT,
271
+ sep="<|im_end|>",
272
+ )
273
+
274
+ conv_bair_v1 = Conversation(
275
+ system="BEGINNING OF CONVERSATION:",
276
+ roles=("USER", "GPT"),
277
+ messages=(),
278
+ offset=0,
279
+ sep_style=SeparatorStyle.TWO,
280
+ sep=" ",
281
+ sep2="</s>",
282
+ )
283
+
284
+ simple_conv = Conversation(
285
+ system="A chat between a curious human and an artificial intelligence assistant. "
286
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
287
+ roles=("Human", "Assistant"),
288
+ messages=(
289
+ ("Human", "Hi!"),
290
+ ("Assistant", "Hi there! How can I help you today?")
291
+ ),
292
+ offset=2,
293
+ sep_style=SeparatorStyle.SINGLE,
294
+ sep="###",
295
+ )
296
+
297
+ simple_conv_multimodal = Conversation(
298
+ system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
299
+ "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
300
+ "Follow the instructions carefully and explain your answers in detail.",
301
+ roles=("Human", "Assistant"),
302
+ messages=(
303
+ ("Human", "Hi!"),
304
+ ("Assistant", "Hi there! How can I help you today?\n")
305
+ ),
306
+ offset=2,
307
+ sep_style=SeparatorStyle.SINGLE,
308
+ sep="###",
309
+ )
310
+
311
+ simple_conv_mpt_multimodal = Conversation(
312
+ system="""<|im_start|>system
313
+ - You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab.
314
+ - You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
315
+ - You should follow the instructions carefully and explain your answers in detail.""",
316
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
317
+ version="mpt",
318
+ messages=(),
319
+ offset=0,
320
+ sep_style=SeparatorStyle.MPT,
321
+ sep="<|im_end|>",
322
+ )
323
+
324
+ simple_conv_legacy = Conversation(
325
+ system="You are LLaVA, a large language model trained by UW Madison WAIV Lab."
326
+ "You are designed to assist human with a variety of tasks using natural language."
327
+ "Follow the instructions carefully.",
328
+ roles=("Human", "Assistant"),
329
+ messages=(
330
+ ("Human", "Hi!\n\n### Response:"),
331
+ ("Assistant", "Hi there! How can I help you today?\n")
332
+ ),
333
+ offset=2,
334
+ sep_style=SeparatorStyle.SINGLE,
335
+ sep="###",
336
+ )
337
+
338
+ conv_llava_v1 = Conversation(
339
+ system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
340
+ "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
341
+ "Follow the instructions carefully and explain your answers in detail.",
342
+ roles=("USER", "ASSISTANT"),
343
+ version="v1",
344
+ messages=(),
345
+ offset=0,
346
+ sep_style=SeparatorStyle.TWO,
347
+ sep=" ",
348
+ sep2="</s>",
349
+ )
350
+
351
+ default_conversation = conv_v1_2
352
+ conv_templates = {
353
+ "default": conv_v1_2,
354
+ "simple": simple_conv,
355
+ "simple_legacy": simple_conv_legacy,
356
+ "multimodal": simple_conv_multimodal,
357
+ "mpt_multimodal": simple_conv_mpt_multimodal,
358
+ "llava_v1": conv_llava_v1,
359
+
360
+ # fastchat
361
+ "v1": conv_v1_2,
362
+ "bair_v1": conv_bair_v1,
363
+ "vicuna_v1_1": conv_vicuna_v1_1,
364
+ "mpt": conv_mpt,
365
+ "mpt_text": conv_mpt_text,
366
+ }
367
+
368
+ if __name__ == "__main__":
369
+ print(default_conversation.get_prompt())
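
For reference, this is how `app.py` turns these templates into a model prompt: copy the `"v1"` template, append the dialog turns, and call `get_prompt()`. A minimal sketch (output shown approximately):

```python
# Build a single-turn prompt the way app.py does; passing None as the
# assistant message leaves the slot open for generation.
from data.conversation_lib import conv_templates

conv = conv_templates["v1"].copy()
conv.append_message(conv.roles[0], "What is shown in the image?")
conv.append_message(conv.roles[1], None)
print(conv.get_prompt())
# Roughly: "<system prompt>\n\n###Human: What is shown in the image?\n###Assistant:"
```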
data/fintune_dataset.py ADDED
@@ -0,0 +1,449 @@
1
+ import warnings
2
+
3
+ import torch
4
+ import yaml
5
+ from torch.utils.data import Dataset
6
+ from PIL import Image
7
+ import json
8
+ from model.tokenizer import Tokenizer
9
+ import os
10
+ import torchvision.transforms as transforms
11
+ import random
12
+ import torchvision.transforms.functional as F
13
+ import torchaudio
14
+ from . import conversation_lib
15
+
16
+ import numpy as np
17
+ from . import video_utils
18
+ from .imu_utils import get_imu_frames
19
+
20
+
21
+ IGNORE_INDEX = -100
22
+
23
+ DEFAULT_IMAGE_TOKEN = "<image>"
24
+ try:
25
+ from torchvision.transforms import InterpolationMode
26
+
27
+ BICUBIC = InterpolationMode.BICUBIC
28
+ except ImportError:
29
+ BICUBIC = Image.BICUBIC
30
+
31
+ T_random_resized_crop = transforms.Compose([
32
+ transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333), interpolation=BICUBIC,
33
+ antialias=None), # 3 is bicubic
34
+ transforms.ToTensor(),
35
+ transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
36
+
37
+
38
+ # image transform
39
+ transform_img_train = transforms.Compose([
40
+ transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(
41
+ 0.75, 1.3333), interpolation=3, antialias=None), # 3 is bicubic
42
+ transforms.ToTensor(),
43
+ transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
44
+
45
+
46
+ class PairRandomResizedCrop(transforms.RandomResizedCrop):
47
+ def forward(self, imgs):
48
+ i, j, h, w = self.get_params(imgs[0], self.scale, self.ratio)
49
+ return [F.resized_crop(img, i, j, h, w, self.size, self.interpolation, antialias=self.antialias) for img in imgs]
50
+
51
+
52
+ class PairToTensor(transforms.ToTensor):
53
+ def __call__(self, pics):
54
+ return [F.to_tensor(pic) for pic in pics]
55
+
56
+
57
+ class PairNormalize(transforms.Normalize):
58
+ def forward(self, tensors):
59
+ return [F.normalize(tensor, self.mean, self.std, self.inplace) for tensor in tensors]
60
+
61
+
62
+ transform_pairimg_train = transforms.Compose([
63
+ PairRandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(
64
+ 0.75, 1.3333), interpolation=3, antialias=None), # 3 is bicubic
65
+ PairToTensor(),
66
+ PairNormalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
67
+
68
+
69
+ def pc_norm(pc):
70
+ """ pc: NxC, return NxC """
71
+ xyz = pc[:, :3]
72
+ other_feature = pc[:, 3:]
73
+
74
+ centroid = torch.mean(xyz, dim=0)
75
+ xyz = xyz - centroid
76
+ m = torch.max(torch.sqrt(torch.sum(xyz ** 2, dim=1)))
77
+ xyz = xyz / m
78
+
79
+ pc = torch.cat((xyz, other_feature), dim=1)
80
+ return pc
81
+
82
+
83
+ def make_audio_features(wav_name, mel_bins=128, target_length=1024, aug=False):
84
+ waveform, sr = torchaudio.load(wav_name)
85
+ # assert sr == 16000, 'input audio sampling rate must be 16kHz'
86
+ if sr != 16000:
87
+ trans = torchaudio.transforms.Resample(sr, 16000)
88
+ waveform = trans(waveform)
89
+
90
+ waveform = waveform - waveform.mean()
91
+
92
+ fbank = torchaudio.compliance.kaldi.fbank(
93
+ waveform, htk_compat=True, sample_frequency=16000, use_energy=False,
94
+ window_type='hanning', num_mel_bins=mel_bins, dither=0.0, frame_shift=10)
95
+
96
+ n_frames = fbank.shape[0]
97
+
98
+ p = target_length - n_frames
99
+ if p > 0:
100
+ m = torch.nn.ZeroPad2d((0, 0, 0, p))
101
+ fbank = m(fbank)
102
+ elif p < 0:
103
+ fbank = fbank[0:target_length, :]
104
+
105
+ if aug:
106
+ freqm = torchaudio.transforms.FrequencyMasking(48)
107
+ timem = torchaudio.transforms.TimeMasking(192)
108
+ fbank = torch.transpose(fbank, 0, 1)
109
+ fbank = fbank.unsqueeze(0)
110
+ fbank = freqm(fbank)
111
+ fbank = timem(fbank)
112
+ fbank = fbank.squeeze(0)
113
+ fbank = torch.transpose(fbank, 0, 1)
114
+
115
+ fbank = (fbank - (-4.2677393)) / (4.5689974 * 2)
116
+ return fbank
117
+
118
+
119
+ class ConversationGenerator:
120
+ def __init__(self, tokenizer):
121
+ self.tokenizer = tokenizer
122
+ self.header = f"{conversation_lib.default_conversation.system}\n\n"
123
+ self._probe_tokenizer_style()
124
+
125
+ def _probe_tokenizer_style(self):
126
+ """
127
+ Given a sentence, e.g. "My darling", some tokenizers will make the space a separate token,
128
+ while some others will merge the space into the next word, forming a token representing " darling".
129
+ Knowing which style the tokenizer takes is necessary for correct ground-truth label masking.
130
+
131
+ """
132
+ probe = "Probe am I"
133
+ sentence1 = self.tokenizer.encode(conversation_lib.default_conversation.roles[1] + ": " + probe,
134
+ bos=False, eos=False)
135
+ sentence2 = self.tokenizer.encode(probe,
136
+ bos=False, eos=False)
137
+ if sentence1[-len(sentence2):] == sentence2:
138
+ self.space_before_to_predict = False
139
+ else:
140
+ sentence3 = self.tokenizer.encode(" " + probe,
141
+ bos=False, eos=False)
142
+ assert sentence1[-len(sentence3):] == sentence3
143
+ self.space_before_to_predict = True
144
+
145
+ def add_speaker_and_signal(self, source, get_conversation=True):
146
+ """Add speaker and start/end signal on each round."""
147
+ BEGIN_SIGNAL = "### "
148
+ END_SIGNAL = "\n"
149
+ conversation = self.header
150
+
151
+ to_predict_list = []
152
+
153
+ for sentence in source:
154
+ from_str = sentence["from"]
155
+ if from_str.lower() in ["human"]:
156
+ from_str = conversation_lib.default_conversation.roles[0]
157
+ elif from_str.lower() in ["gpt", "assistant"]:
158
+ from_str = conversation_lib.default_conversation.roles[1]
159
+ else:
160
+ raise ValueError(f"unknown dialog role: {from_str.lower()}")
161
+
162
+ value = sentence["value"]
163
+ if DEFAULT_IMAGE_TOKEN in value:
164
+ value = value.replace(DEFAULT_IMAGE_TOKEN, '').strip()
165
+
166
+ sentence_value = BEGIN_SIGNAL + from_str + ": " + value + END_SIGNAL
167
+
168
+ if from_str == conversation_lib.default_conversation.roles[1]:
169
+ to_predict_value = value + END_SIGNAL + "###"
170
+ if self.space_before_to_predict:
171
+ to_predict_value = " " + to_predict_value
172
+ to_predict_list.append(to_predict_value)
173
+
174
+ if get_conversation:
175
+ conversation = conversation + sentence_value
176
+
177
+ conversation = conversation + BEGIN_SIGNAL
178
+ return conversation, to_predict_list
179
+
180
+
181
+ DATASETS = dict(
182
+ image=[
183
+ dict(path="datasets/InstructionTuning/image/llava_v1_5_mix665k_image.json", type='image'),
184
+ dict(path='datasets/InstructionTuning/image/cococap_train.json', type='image'),
185
+ dict(path="datasets/InstructionTuning/image/llava_v1_5_mix665k_text.json", type='text'),
186
+ ],
187
+ audio=[
188
+ dict(path="datasets/InstructionTuning/audio/audiocap_train.json", type='audio'),
189
+ dict(path="datasets/InstructionTuning/audio/audiocap_val.json", type='audio'),
190
+ dict(path="datasets/InstructionTuning/audio/audio_conversation.json", type='audio'),
191
+ ],
192
+ video=[
193
+ dict(path="datasets/InstructionTuning/video/msrvtt_cap_trainval.json", type='video'),
194
+ dict(path="datasets/InstructionTuning/video/msrvtt_cap_test.json", type='video'),
195
+ dict(path="datasets/InstructionTuning/video/msrvtt_vqa_train.json", type='video'),
196
+ dict(path="datasets/InstructionTuning/video/msrvtt_vqa_val.json", type='video'),
197
+ dict(path="datasets/InstructionTuning/video/msrvtt_vqa_test.json", type='video'),
198
+ dict(path="datasets/InstructionTuning/video/video_complex_reasoning_10k.json", type='video'),
199
+ dict(path="datasets/InstructionTuning/video/video_conversation_10k.json", type='video'),
200
+ dict(path="datasets/InstructionTuning/video/video_detail_10k.json", type='video'),
201
+ ],
202
+ point=[
203
+ dict(path="datasets/InstructionTuning/point/pointllm_70k_formated.json", type='point'),
204
+ ],
205
+ rgbd=[
206
+ dict(path="datasets/InstructionTuning/depth_normal/llava_instruct_50k_depth.json", type='rgbd'),
207
+ ],
208
+ rgbn=[
209
+ dict(path="datasets/InstructionTuning/depth_normal/llava_instruct_50k_normal.json", type='rgbn'),
210
+ ],
211
+ imu=[
212
+ dict(path="datasets/InstructionTuning/imu/imu_fixed_50k.json", type='imu'),
213
+ ],
214
+ fmri=[
215
+ dict(path="datasets/InstructionTuning/fmri/fmri_fixed.json", type='fmri'),
216
+ ],
217
+ )
218
+ IMU_PATH = "/mnt/petrelfs/share_data/hanjiaming/ego4d/v2/processed_imu/"
219
+
220
+
221
+ class FinetuneDialogDataset(Dataset):
222
+ def __init__(self, dataset=['image'], transform=T_random_resized_crop, max_words=2048, image_words=30, tokenizer_path=None):
223
+ if isinstance(dataset, str):
224
+ dataset = [dataset]
225
+
226
+ self.dataset = dataset
227
+
228
+ group_ann = {}
229
+ for d in dataset:
230
+ for meta in DATASETS[d]:
231
+ meta_path, meta_type = meta['path'], meta['type']
232
+ meta_ext = os.path.splitext(meta_path)[-1]
233
+ if meta_ext == ".json":
234
+ with open(meta_path) as f:
235
+ meta_l = json.load(f)
236
+ # add data_type
237
+ # this is a temp solution
238
+ new_meta_l = []
239
+ for l in meta_l:
240
+ l['data_type'] = meta_type
241
+ new_meta_l.append(l)
242
+ meta_l = new_meta_l
243
+ elif meta_ext == ".jsonl":
244
+ meta_l = []
245
+ with open(meta_path) as f:
246
+ for i, line in enumerate(f):
247
+ try:
248
+ meta_l.append(json.loads(line))
249
+ except json.decoder.JSONDecodeError as e:
250
+ print(
251
+ f"Error decoding the following jsonl line ({i}):\n{line.rstrip()}", force=True)
252
+ raise e
253
+ else:
254
+ raise NotImplementedError(
255
+ f"Unknown meta file extension: \"{meta_ext}\". "
256
+ f"Currently, .json, .jsonl are supported. "
257
+ "If you are using a supported format, please set the file extension so that the proper parsing "
258
+ "routine can be called."
259
+ )
260
+ if meta_type not in group_ann:
261
+ group_ann[meta_type] = []
262
+ print(f"{meta_path}, type {meta_type}: len {len(meta_l)}")
263
+ group_ann[meta_type] += meta_l
264
+
265
+ # sort group_ann for higher efficiency (items in one global batch with similar length)
266
+ for meta_type, meta_l in group_ann.items():
267
+ meta_l.sort(key=lambda data_item: sum(
268
+ [len(_['value']) for _ in data_item['conversations']]))
269
+
270
+ self.group_ann = group_ann
271
+ self.ann = sum(list(self.group_ann.values()), start=[])
272
+
273
+ self.group_indices = {}
274
+ start_pos = 0
275
+ for meta_type, meta_l in self.group_ann.items():
276
+ self.group_indices[meta_type] = list(
277
+ range(start_pos, start_pos + len(meta_l)))
278
+ start_pos = start_pos + len(meta_l)
279
+
280
+ print(f"total length: {len(self)}")
281
+ self.transform = transform
282
+ print(f"transform:\n{self.transform}")
283
+ self.max_words = max_words
284
+ self.image_words = image_words
285
+ self.tokenizer = Tokenizer(model_path=tokenizer_path)
286
+ self.conversation_generator = ConversationGenerator(self.tokenizer)
287
+
288
+ self.load_funcs = dict(
289
+ image=self.load_image,
290
+ audio=self.load_audio,
291
+ video=self.load_video,
292
+ point=self.load_point,
293
+ rgbd=self.load_rgbx,
294
+ rgbn=self.load_rgbx,
295
+ imu=self.load_imu,
296
+ fmri=self.load_fmri
297
+ )
298
+
299
+ def __len__(self):
300
+ return len(self.ann)
301
+
302
+ def load_image(self, data):
303
+ filename = data['image']
304
+ image = Image.open(filename).convert('RGB')
305
+ image = self.transform(image)
306
+ return image
307
+
308
+ def load_audio(self, data):
309
+ audio_path = data['image']
310
+ fbank = make_audio_features(audio_path, mel_bins=128)
311
+ fbank = fbank.transpose(0, 1)[None] # [1, 128, 1024]
312
+ return fbank
313
+
314
+ def load_video(self, data):
315
+ video_path = data['image']
316
+ video_feats = video_utils.load_and_transform_video_data(
317
+ video_path, video_path, clip_duration=1, clips_per_video=5)
318
+ return video_feats[:, :, 0]
319
+
320
+ def load_point(self, data):
321
+ point_path = data['image']
322
+ point_feat = torch.load(point_path, map_location='cpu')
323
+ point_feat = point_feat.transpose(0, 1)
324
+ return point_feat
325
+
326
+ def load_rgbx(self, data):
327
+ image_path = data['image']
328
+ x_image_path = data['depth_image'] if 'depth_image' in data else data['normal_image']
329
+ image = Image.open(image_path).convert('RGB')
330
+ x_image = Image.open(x_image_path).convert('RGB')
331
+ x_image = x_image.resize(image.size[-2:])
332
+
333
+ image, x_image = transform_pairimg_train([image, x_image])
334
+ # [2, 3, H, W]
335
+ image = torch.stack([image, x_image], dim=0)
336
+ return image
337
+
338
+ def load_fmri(self, data):
339
+ fmri_path = data['image']
340
+ data = np.load(fmri_path)
341
+ data = data.mean(axis=0)
342
+ data = torch.tensor(data[None])
343
+ return data
344
+
345
+ def load_imu(self, data_dict):
346
+ uid = data_dict["video_uid"]
347
+ w_s = data_dict["window_start"]
348
+ w_e = data_dict["window_end"]
349
+
350
+ imu_data = get_imu_frames(
351
+ IMU_PATH, uid,
352
+ video_start_sec=w_s,
353
+ video_end_sec=w_e,
354
+ )
355
+ if imu_data is None:
356
+ raise ValueError
357
+ return imu_data['signal']
358
+
359
+ def __getitem__(self, index, expect_type=None):
360
+ if expect_type is None:
361
+ data_item = self.ann[index]
362
+ else:
363
+ # in case we want get data from specific data_type
364
+ data_item = self.group_ann[expect_type][index]
365
+
366
+ data_type = data_item['data_type']
367
+ if data_type != 'text':
368
+ if data_type in self.load_funcs:
369
+ try:
370
+ image = self.load_funcs[data_type](data_item)
371
+ if image is None:
372
+ raise ValueError('Data is None')
373
+ except Exception:
374
+ print('Error', data_item)
375
+ rand_idx = random.randint(
376
+ 0, len(self.group_ann[data_type]) - 1)
377
+ return self.__getitem__(rand_idx, expect_type=data_type)
378
+ else:
379
+ raise ValueError(f'Does not support {data_type}')
380
+ else:
381
+ image = None
382
+ # warnings.warn("pure black image for examples without image")
383
+ # image = torch.zeros(3, 224, 224)
384
+
385
+ source = data_item["conversations"]
386
+ conversation, to_predict_values = self.conversation_generator.add_speaker_and_signal(
387
+ source)
388
+ if len(to_predict_values) == 0:
389
+ warnings.warn(
390
+ f"see dialog data with nothing to predict, data: {data_item}")
391
+ return self[index-1]
392
+
393
+ tokenzed_conversation = self.tokenizer.encode(
394
+ conversation, bos=True, eos=True)
395
+ labels = [IGNORE_INDEX for _ in tokenzed_conversation]
396
+
397
+ check_pos = 0
398
+ for value in to_predict_values:
399
+ tokenized_value = self.tokenizer.encode(
400
+ value, bos=False, eos=False)
401
+ value_pos = find_sublist(
402
+ tokenzed_conversation[check_pos:], tokenized_value)
403
+ if value_pos == -1:
404
+ print(
405
+ "a sentence mismatches the corresponding piece in the conversation")
406
+ return self[index-1]
+ value_pos += check_pos
407
+ labels[value_pos:value_pos+len(tokenized_value)] = tokenized_value
408
+ assert labels[value_pos:value_pos+len(
409
+ tokenized_value)] == tokenzed_conversation[value_pos:value_pos+len(tokenized_value)]
410
+ check_pos = value_pos+len(tokenized_value)
411
+
412
+ input2 = torch.tensor(tokenzed_conversation, dtype=torch.int64)
413
+ labels = torch.tensor(labels, dtype=torch.int64)
414
+
415
+ if image is not None:
416
+ max_words = self.max_words - self.image_words
417
+ else:
418
+ max_words = self.max_words
419
+ padding = max_words - input2.shape[0]
420
+ if padding > 0:
421
+ input2 = torch.cat(
422
+ (input2, torch.zeros(padding, dtype=torch.int64) - 1))
423
+ labels = torch.cat(
424
+ (labels, torch.zeros(padding, dtype=torch.int64) - 1))
425
+ elif padding < 0:
426
+ input2 = input2[:max_words]
427
+ labels = labels[:max_words]
428
+
429
+ input2_mask = input2.ge(0)
430
+ label_mask = labels.ge(0)
431
+ input2[~input2_mask] = 0
432
+ labels[~label_mask] = 0
433
+ input2_mask = input2_mask.float()
434
+ label_mask = label_mask.float()
435
+ if image is None:
436
+ return input2, labels, data_item['data_type']
437
+ else:
438
+ return input2, labels, image, data_item['data_type']
439
+
440
+ def groups(self):
441
+ return list(self.group_indices.values())
442
+
443
+
444
+ def find_sublist(a: list, b: list):
445
+ len_a, len_b = len(a), len(b)
446
+ for i in range(len_a - len_b + 1):
447
+ if a[i:i+len_b] == b:
448
+ return i
449
+ return -1
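A minimal sketch of how find_sublist is used above to supervise only the assistant replies; the token ids and the IGNORE_INDEX value of -100 below are illustrative assumptions, not values taken from this commit:

# Hypothetical token ids, for illustration only.
conversation_ids = [1, 306, 626, 263, 13563, 7451, 29889, 2]   # full tokenized dialog
reply_ids = [263, 13563, 7451]                                  # one assistant reply

pos = find_sublist(conversation_ids, reply_ids)                 # -> 3
labels = [-100] * len(conversation_ids)                         # assume IGNORE_INDEX == -100
if pos != -1:
    labels[pos:pos + len(reply_ids)] = reply_ids                # loss only on the reply tokens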
data/imu_utils.py ADDED
@@ -0,0 +1,257 @@
1
+ import string
2
+ import numpy as np
3
+ import matplotlib.animation as animation
4
+ from matplotlib import pyplot as plt
5
+ import json
6
+ from collections import defaultdict
7
+ from bisect import bisect_left
8
+ import os
9
+ import torch
10
+ import torchaudio
11
+ torchaudio.set_audio_backend("sox_io")
12
+
13
+
14
+ def load_json(json_path: str):
15
+ """
16
+ Load a json file
17
+ """
18
+ with open(json_path, "r", encoding="utf-8") as f_name:
19
+ data = json.load(f_name)
20
+ return data
21
+
22
+
23
+ def check_window_signal(info_t, w_s, w_e):
24
+ length = w_e - w_s
25
+ frame_offset = int(w_s * info_t.sample_rate)
26
+ num_frames = int(length * info_t.sample_rate)
27
+ if frame_offset + num_frames > int(info_t.num_frames):
28
+ return False
29
+ else:
30
+ return True
31
+
32
+
33
+ def index_narrations(ann_path):
34
+ narration_raw = load_json(ann_path)
35
+
36
+ narration_dict = defaultdict(list)
37
+ summary_dict = defaultdict(list)
38
+ avg_len = []
39
+ for v_id, narr in narration_raw.items():
40
+ narr_list = []
41
+ summ_list = []
42
+ if "narration_pass_1" in narr:
43
+ narr_list += narr["narration_pass_1"]["narrations"]
44
+ summ_list += narr["narration_pass_1"]["summaries"]
45
+ if "narration_pass_2" in narr:
46
+ narr_list += narr["narration_pass_2"]["narrations"]
47
+ summ_list += narr["narration_pass_2"]["summaries"]
48
+
49
+ if len(narr_list) > 0:
50
+ narration_dict[v_id] = [
51
+ (
52
+ float(n_t["timestamp_sec"]),
53
+ n_t["narration_text"],
54
+ n_t["annotation_uid"],
55
+ n_t["timestamp_frame"],
56
+ )
57
+ for n_t in narr_list
58
+ ]
59
+ avg_len.append(len(narration_dict[v_id]))
60
+ else:
61
+ narration_dict[v_id] = []
62
+ if len(summ_list) > 0:
63
+ summary_dict[v_id] = [
64
+ (
65
+ float(s_t["start_sec"]),
66
+ float(s_t["end_sec"]),
67
+ s_t["summary_text"],
68
+ )
69
+ for s_t in summ_list
70
+ ]
71
+ else:
72
+ summary_dict[v_id] = []
73
+ # print(f"Number of Videos with narration {len(narration_dict)}")
74
+ # print(f"Avg. narration length {np.mean(avg_len)}")
75
+ # print(f"Number of Videos with summaries {len(summary_dict)}")
76
+ return narration_dict, summary_dict
77
+
78
+
79
+ def get_signal_info(signal_fn: str):
80
+ return torchaudio.info(signal_fn)
81
+
82
+
83
+ def get_signal_frames(signal_fn: str, video_start_sec: float, video_end_sec: float):
84
+ """
85
+ Given a signal track return the frames between video_start_sec and video_end_sec
86
+ """
87
+ info_t = get_signal_info(signal_fn)
88
+
89
+ length = video_end_sec - video_start_sec
90
+ aframes, _ = torchaudio.load(
91
+ signal_fn,
92
+ normalize=True,
93
+ frame_offset=int(video_start_sec * info_t.sample_rate),
94
+ num_frames=int(length * info_t.sample_rate),
95
+ )
96
+ return {"signal": aframes, "meta": info_t}
97
+
98
+
99
+ def tosec(value):
100
+ return value / 1000
101
+
102
+
103
+ def toms(value):
104
+ return value * 1000
105
+
106
+
107
+ def delta(first_num: float, second_num: float):
108
+ """Compute the absolute value of the difference of two numbers"""
109
+ return abs(first_num - second_num)
110
+
111
+
112
+ def padIMU(signal, duration_sec):
113
+ """
114
+ Pad the signal if necessary
115
+ """
116
+ expected_elements = round(duration_sec) * 200
117
+
118
+ if signal.shape[0] > expected_elements:
119
+ signal = signal[:expected_elements, :]
120
+ elif signal.shape[0] < expected_elements:
121
+ padding = expected_elements - signal.shape[0]
122
+ padded_zeros = np.zeros((padding, 6))
123
+ signal = np.concatenate([signal, padded_zeros], 0)
124
+ # signal = signal[:expected_elements, :]
125
+ return signal
126
+
127
+
128
+ def resample(
129
+ signals: np.ndarray,
130
+ timestamps: np.ndarray,
131
+ original_sample_rate: int,
132
+ resample_rate: int,
133
+ ):
134
+ """
135
+ Resamples data to new sample rate
136
+ """
137
+ signals = torch.as_tensor(signals)
138
+ timestamps = torch.from_numpy(timestamps).unsqueeze(-1)
139
+ signals = torchaudio.functional.resample(
140
+ waveform=signals.data.T,
141
+ orig_freq=original_sample_rate,
142
+ new_freq=resample_rate,
143
+ ).T.numpy()
144
+
145
+ nsamples = len(signals)
146
+
147
+ period = 1 / resample_rate
148
+
149
+ # timestamps are expected to be shape (N, 1)
150
+ initital_seconds = timestamps[0] / 1e3
151
+
152
+ ntimes = (torch.arange(nsamples) * period).view(-1, 1) + initital_seconds
153
+
154
+ timestamps = (ntimes * 1e3).squeeze().numpy()
155
+ return signals, timestamps
156
+
157
+
158
+ def resampleIMU(signal, timestamps):
159
+ sampling_rate = int(1000 * (1 / (np.mean(np.diff(timestamps)))))
160
+ # resample all to 200hz
161
+ if sampling_rate != 200:
162
+ signal, timestamps = resample(signal, timestamps, sampling_rate, 200)
163
+ return signal, timestamps
164
+
165
+
166
+ def get_imu_frames(
167
+ imu_path,
168
+ uid: str,
169
+ video_start_sec: float,
170
+ video_end_sec: float,
171
+ ):
172
+ """
173
+ Given a IMU signal return the frames between video_start_sec and video_end_sec
174
+ """
175
+ signal = np.load(os.path.join(imu_path, f"{uid}.npy"))
176
+ signal = signal.transpose()
177
+ timestamps = np.load(os.path.join(imu_path, f"{uid}_timestamps.npy"))
178
+
179
+ if toms(video_start_sec) > timestamps[-1] or toms(video_end_sec) > timestamps[-1]:
180
+ return None
181
+
182
+ start_id = bisect_left(timestamps, toms(video_start_sec))
183
+ end_id = bisect_left(timestamps, toms(video_end_sec))
184
+
185
+ # make sure the retrieved window interval is correct, within a max margin of 4 sec
186
+ if (
187
+ delta(video_start_sec, tosec(timestamps[start_id])) > 4
188
+ or delta(video_end_sec, tosec(timestamps[end_id])) > 4
189
+ ):
190
+ return None
191
+
192
+ # get the window
193
+ if start_id == end_id:
194
+ start_id -= 1
195
+ end_id += 1
196
+ signal, timestamps = signal[start_id:end_id], timestamps[start_id:end_id]
197
+
198
+ if len(signal) < 10 or len(timestamps) < 10:
199
+ return None
200
+ # resample the signal at 200hz if necessary
201
+ signal, timestamps = resampleIMU(signal, timestamps)
202
+
203
+ # pad the signal if necessary
204
+ signal = padIMU(signal, video_end_sec - video_start_sec)
205
+
206
+ sample_dict = {
207
+ "timestamp": timestamps,
208
+ "signal": torch.tensor(signal.T),
209
+ "sampling_rate": 200,
210
+ }
211
+
212
+ return sample_dict
213
+
214
+
215
+ def display_animation(frames, title, save_path_gif):
216
+ fig, ax = plt.subplots()
217
+ frames = [[ax.imshow(frames[i])] for i in range(len(frames))]
218
+ plt.title(title)
219
+ ani = animation.ArtistAnimation(fig, frames)
220
+ ani.save(save_path_gif, writer="imagemagick")
221
+ plt.close()
222
+
223
+
224
+ def display_animation_imu(frames, imu, title, save_path_gif):
225
+ fig, (ax1, ax2, ax3) = plt.subplots(3, 1)
226
+ ax1.set_title(title)
227
+ ax2.set_title("Acc.")
228
+ ax3.set_title("Gyro.")
229
+ frames = [[ax1.imshow(frames[i])] for i in range(len(frames))]
230
+ ani = animation.ArtistAnimation(fig, frames)
231
+
232
+ ax2.plot(imu[0].cpu().numpy(), color="red")
233
+ ax2.plot(imu[1].cpu().numpy(), color="blue")
234
+ ax2.plot(imu[2].cpu().numpy(), color="green")
235
+ ax3.plot(imu[3].cpu().numpy(), color="red")
236
+ ax3.plot(imu[4].cpu().numpy(), color="blue")
237
+ ax3.plot(imu[5].cpu().numpy(), color="green")
238
+ plt.tight_layout()
239
+ ani.save(save_path_gif, writer="imagemagick")
240
+ plt.close()
241
+
242
+
243
+ def filter_narration(narration_text: str) -> bool:
244
+ if "#c" in narration_text.lower():
245
+ return True
246
+ return False
247
+
248
+
249
+ def clean_narration_text(narration_text: str) -> str:
250
+ return (
251
+ narration_text.replace("#C C ", "")
252
+ .replace("#C", "")
253
+ .replace("#unsure", "something")
254
+ .strip()
255
+ .strip(string.punctuation)
256
+ .lower()[:128]
257
+ )
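A minimal usage sketch for the IMU loader above, assuming a hypothetical directory that holds <uid>.npy and <uid>_timestamps.npy files in the layout get_imu_frames expects:

sample = get_imu_frames(
    "datasets/ego4d/imu",       # hypothetical IMU root directory
    uid="0a1b2c3d",             # hypothetical video uid
    video_start_sec=10.0,
    video_end_sec=15.0,
)
if sample is not None:
    signal = sample["signal"]   # torch.Tensor of shape [6, 5 * 200] after resampling and padding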
data/video_utils.py ADDED
@@ -0,0 +1,204 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from pytorchvideo import transforms as pv_transforms
5
+ from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
6
+ from pytorchvideo.data.encoded_video import EncodedVideo
7
+ from pytorchvideo.data.encoded_video_decord import EncodedVideoDecord
8
+ from torchvision import transforms
9
+ from torchvision.transforms._transforms_video import NormalizeVideo
10
+
11
+
12
+ def get_clip_timepoints(clip_sampler, duration):
13
+ # Read out all clips in this video
14
+ all_clips_timepoints = []
15
+ is_last_clip = False
16
+ end = 0.0
17
+ while not is_last_clip:
18
+ start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
19
+ all_clips_timepoints.append((start, end))
20
+ return all_clips_timepoints
21
+
22
+
23
+
24
+ def crop_boxes(boxes, x_offset, y_offset):
25
+ """
26
+ Perform crop on the bounding boxes given the offsets.
27
+ Args:
28
+ boxes (ndarray or None): bounding boxes to perform crop. The dimension
29
+ is `num boxes` x 4.
30
+ x_offset (int): cropping offset in the x axis.
31
+ y_offset (int): cropping offset in the y axis.
32
+ Returns:
33
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
34
+ `num boxes` x 4.
35
+ """
36
+ cropped_boxes = boxes.copy()
37
+ cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
38
+ cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
39
+
40
+ return cropped_boxes
41
+
42
+
43
+ def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
44
+ """
45
+ Perform uniform spatial sampling on the images and corresponding boxes.
46
+ Args:
47
+ images (tensor): images to perform uniform crop. The dimension is
48
+ `num frames` x `channel` x `height` x `width`.
49
+ size (int): size of height and width to crop the images.
50
+ spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
51
+ is larger than height. Or 0, 1, or 2 for top, center, and bottom
52
+ crop if height is larger than width.
53
+ boxes (ndarray or None): optional. Corresponding boxes to images.
54
+ Dimension is `num boxes` x 4.
55
+ scale_size (int): optional. If not None, resize the images to scale_size before
56
+ performing any crop.
57
+ Returns:
58
+ cropped (tensor): images with dimension of
59
+ `num frames` x `channel` x `size` x `size`.
60
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
61
+ `num boxes` x 4.
62
+ """
63
+ assert spatial_idx in [0, 1, 2]
64
+ ndim = len(images.shape)
65
+ if ndim == 3:
66
+ images = images.unsqueeze(0)
67
+ height = images.shape[2]
68
+ width = images.shape[3]
69
+
70
+ if scale_size is not None:
71
+ if width <= height:
72
+ width, height = scale_size, int(height / width * scale_size)
73
+ else:
74
+ width, height = int(width / height * scale_size), scale_size
75
+ images = torch.nn.functional.interpolate(
76
+ images,
77
+ size=(height, width),
78
+ mode="bilinear",
79
+ align_corners=False,
80
+ )
81
+
82
+ y_offset = int(math.ceil((height - size) / 2))
83
+ x_offset = int(math.ceil((width - size) / 2))
84
+
85
+ if height > width:
86
+ if spatial_idx == 0:
87
+ y_offset = 0
88
+ elif spatial_idx == 2:
89
+ y_offset = height - size
90
+ else:
91
+ if spatial_idx == 0:
92
+ x_offset = 0
93
+ elif spatial_idx == 2:
94
+ x_offset = width - size
95
+ cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
96
+ cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
97
+ if ndim == 3:
98
+ cropped = cropped.squeeze(0)
99
+ return cropped, cropped_boxes
100
+
101
+
102
+ class SpatialCrop(nn.Module):
103
+ """
104
+ Convert the video into 3 smaller clips spatially. Must be used after the
105
+ temporal crops to get spatial crops, and should be used with
106
+ -2 in the spatial crop at the slowfast augmentation stage (so full
107
+ frames are passed in here). Will return a larger list with the
108
+ 3x spatial crops as well.
109
+ """
110
+
111
+ def __init__(self, crop_size: int = 224, num_crops: int = 3):
112
+ super().__init__()
113
+ self.crop_size = crop_size
114
+ if num_crops == 3:
115
+ self.crops_to_ext = [0, 1, 2]
116
+ self.flipped_crops_to_ext = []
117
+ elif num_crops == 1:
118
+ self.crops_to_ext = [1]
119
+ self.flipped_crops_to_ext = []
120
+ else:
121
+ raise NotImplementedError("Nothing else supported yet")
122
+
123
+ def forward(self, videos):
124
+ """
125
+ Args:
126
+ videos: A list of C, T, H, W videos.
127
+ Returns:
128
+ videos: A list with 3x the number of elements. Each video converted
129
+ to C, T, H', W' by spatial cropping.
130
+ """
131
+ assert isinstance(videos, list), "Must be a list of videos after temporal crops"
132
+ assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)"
133
+ res = []
134
+ for video in videos:
135
+ for spatial_idx in self.crops_to_ext:
136
+ res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
137
+ if not self.flipped_crops_to_ext:
138
+ continue
139
+ flipped_video = transforms.functional.hflip(video)
140
+ for spatial_idx in self.flipped_crops_to_ext:
141
+ res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
142
+ return res
143
+
144
+
145
+ def load_and_transform_video_data(
146
+ video_file,
147
+ video_path,
148
+ clip_duration=2,
149
+ clips_per_video=5,
150
+ sample_rate=16000,
151
+ with_audio=False
152
+ ):
153
+ video_transform = transforms.Compose(
154
+ [
155
+ pv_transforms.ShortSideScale(224),
156
+ NormalizeVideo(
157
+ mean=(0.48145466, 0.4578275, 0.40821073),
158
+ std=(0.26862954, 0.26130258, 0.27577711),
159
+ ),
160
+ ]
161
+ )
162
+
163
+ clip_sampler = ConstantClipsPerVideoSampler(
164
+ clip_duration=clip_duration, clips_per_video=clips_per_video
165
+ )
166
+ frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration)
167
+
168
+ if isinstance(video_file, str):
169
+ video = EncodedVideo.from_path(
170
+ video_file,
171
+ decoder="decord",
172
+ decode_audio=with_audio,
173
+ # **{"sample_rate": sample_rate},
174
+ )
175
+ else:
176
+ video = EncodedVideoDecord(video_file, video_name=video_path, decode_video=True, decode_audio=with_audio, sample_rate=sample_rate)
177
+
178
+ all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration)
179
+
180
+ all_video = []
181
+ for clip_timepoints in all_clips_timepoints:
182
+ # Read the clip, get frames
183
+ clip = video.get_clip(clip_timepoints[0], clip_timepoints[1])
184
+ if clip is None:
185
+ raise ValueError("No clip found")
186
+ video_clip = frame_sampler(clip["video"])
187
+ video_clip = video_clip / 255.0 # since this is float, need 0-1
188
+
189
+ all_video.append(video_clip)
190
+
191
+ all_video = [video_transform(clip) for clip in all_video]
192
+ all_video = SpatialCrop(224, num_crops=3)(all_video)
193
+
194
+ all_video = torch.stack(all_video, dim=0)
195
+
196
+ if not with_audio:
197
+ return all_video
198
+ else:
199
+ return all_video, clip['audio']
200
+
201
+ if __name__ == '__main__':
202
+ video_path = "datasets/InstructionTuning/video/music_aqa/MUSIC-AVQA-videos-Real/00000002.mp4"
203
+ video, audio = load_and_transform_video_data(video_path, video_path, clip_duration=1, clips_per_video=5, with_audio=True)
204
+ import pdb;pdb.set_trace()
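For reference, a shape sketch under the settings used by the dataset above (the video path is a placeholder): with clip_duration=1 and clips_per_video=5, the 3-way SpatialCrop yields 15 clips per video.

clips = load_and_transform_video_data(
    "example.mp4", "example.mp4",         # hypothetical local video path
    clip_duration=1, clips_per_video=5,
)
# clips: [5 * 3, 3, 1, 224, 224]; the dataset then keeps frame 0 of every clip
frames = clips[:, :, 0]                   # [15, 3, 224, 224]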
demos/multi_turn_mm.py ADDED
@@ -0,0 +1,300 @@
1
+ import sys
2
+ import os
3
+ sys.path.append(os.path.abspath(__file__).rsplit('/', 2)[0])
4
+
5
+ import argparse
6
+ import multiprocessing as mp
7
+ import numpy as np
8
+ from typing import List, Optional
9
+
10
+ import torch
11
+ import torch.distributed as dist
12
+
13
+ from fairscale.nn.model_parallel import initialize as fs_init
14
+
15
+ import gradio as gr
16
+ from util.misc import setup_for_distributed
17
+ from util.misc import default_tensor_type
18
+ from model.meta import MetaModel
19
+ from data.conversation_lib import conv_templates, SeparatorStyle
20
+ from PIL import Image
21
+ import torchvision.transforms as transforms
22
+ from data.fintune_dataset import make_audio_features
23
+ from data import video_utils
24
+
25
+
26
+ T_random_resized_crop = transforms.Compose([
27
+ transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333), interpolation=3,
28
+ antialias=None), # 3 is bicubic
29
+ transforms.ToTensor(),
30
+ transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
31
+
32
+
33
+ def load_audio(audio_path):
34
+ fbank = make_audio_features(audio_path, mel_bins=128)
35
+ fbank = fbank.transpose(0, 1)[None] #[1, 128, 1024]
36
+ return fbank
37
+
38
+ def load_video(video_path):
39
+ video_feats = video_utils.load_and_transform_video_data(video_path, video_path, clip_duration=1, clips_per_video=5)
40
+ return video_feats[:, :, 0]
41
+
42
+
43
+ def model_worker(
44
+ rank: int, args: argparse.Namespace, barrier: mp.Barrier,
45
+ request_queue: mp.Queue, response_queue: Optional[mp.Queue] = None,
46
+ ) -> None:
47
+ """
48
+ The worker function that manipulates the GPU to run the inference.
49
+ Exactly n_gpu workers are started, with each one operating on a separate GPU.
50
+
51
+ Args:
52
+ rank (int): Distributed rank of the worker.
53
+ args (argparse.Namespace): All command line arguments.
54
+ barrier (multiprocessing.Barrier): A barrier used to delay the start
55
+ of the Web UI until after the model has started.
56
+ """
57
+
58
+ world_size = len(args.gpu_ids)
59
+ gpu_id = args.gpu_ids[rank]
60
+ dist.init_process_group(
61
+ backend="nccl", rank=rank, world_size=world_size,
62
+ init_method=f"tcp://{args.master_addr}:{args.master_port}",
63
+ )
64
+ print(f"| distributed init on worker {rank}/{world_size}. "
65
+ f"using gpu: {gpu_id}")
66
+ fs_init.initialize_model_parallel(world_size)
67
+ torch.cuda.set_device(gpu_id)
68
+
69
+ torch.manual_seed(1)
70
+ np.random.seed(1)
71
+
72
+ # set the print behavior.
73
+ setup_for_distributed(rank == 0)
74
+
75
+ target_dtype = {
76
+ "bf16": torch.bfloat16,
77
+ "fp16": torch.float16
78
+ }[args.dtype]
79
+ with default_tensor_type(dtype=target_dtype, device="cuda"):
80
+ model = MetaModel(args.llama_type, args.llama_config, tokenizer_path=args.tokenizer_path)
81
+ print("Loading pretrained weights ...")
82
+ checkpoint = torch.load(args.pretrained_path, map_location='cpu')
83
+ msg = model.load_state_dict(checkpoint, strict=False)
84
+ print("load result:\n", msg)
85
+ model.cuda()
86
+ model.eval()
87
+ print(f"Model = {str(model)}")
88
+
89
+ barrier.wait()
90
+
91
+ while True:
92
+ img_path, audio_path, video_path, chatbot, max_gen_len, temperature, top_p, modality = request_queue.get()
93
+ if 'image' in modality and img_path is not None:
94
+ image = Image.open(img_path).convert('RGB')
95
+ inputs = T_random_resized_crop(image)
96
+ elif 'video' in modality and video_path is not None:
97
+ inputs = load_video(video_path)
98
+ elif 'audio' in modality and audio_path is not None:
99
+ inputs = load_audio(audio_path)
100
+ else:
101
+ inputs = None
102
+
103
+ if inputs is not None:
104
+ inputs = inputs[None].cuda().to(target_dtype)
105
+
106
+ conv = conv_templates["v1"].copy()
107
+ for user, bot in chatbot:
108
+ conv.append_message(conv.roles[0], user)
109
+ conv.append_message(conv.roles[1], bot)
110
+
111
+ with torch.cuda.amp.autocast(dtype=target_dtype):
112
+ print(conv.get_prompt())
113
+ for stream_response in model.stream_generate(
114
+ conv.get_prompt(), inputs,
115
+ max_gen_len=max_gen_len, temperature=temperature, top_p=top_p,
116
+ modal = modality
117
+ ):
118
+ conv_sep = (
119
+ conv.sep
120
+ if conv.sep_style == SeparatorStyle.SINGLE
121
+ else conv.sep2
122
+ )
123
+ end_pos = stream_response["text"].find(conv_sep)
124
+ if end_pos != -1:
125
+ stream_response["text"] = (
126
+ stream_response['text'][:end_pos].rstrip() + "\n"
127
+ )
128
+ stream_response["end_of_content"] = True
129
+
130
+ # keep a few characters if not end_of_content to avoid sending
131
+ # part of conv_sep before all of it is generated.
132
+ if not stream_response["end_of_content"]:
133
+ if len(stream_response["text"]) < len(conv_sep):
134
+ continue
135
+ stream_response["text"] = (
136
+ stream_response["text"][:-len(conv_sep)]
137
+ )
138
+
139
+ if response_queue is not None:
140
+ response_queue.put(stream_response)
141
+
142
+ if stream_response["end_of_content"]:
143
+ break
144
+
145
+
146
+ def gradio_worker(
147
+ request_queues: List[mp.Queue], response_queue: mp.Queue,
148
+ args: argparse.Namespace, barrier: mp.Barrier,
149
+ ) -> None:
150
+ """
151
+ The gradio worker is responsible for displaying the WebUI and relaying the
152
+ requests to model workers. It should be launched only once.
153
+
154
+ Args:
155
+ request_queues (List[mp.Queue]): A list of request queues (one for
156
+ each model worker).
157
+ args (argparse.Namespace): All command line arguments.
158
+ barrier (multiprocessing.Barrier): A barrier used to delay the start
159
+ of the Web UI until after the model has started.
160
+ """
161
+
162
+ def show_user_input(msg, chatbot):
163
+ return "", chatbot + [[msg, None]]
164
+
165
+ def stream_model_output(img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality):
166
+ for queue in request_queues:
167
+ queue.put((img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality))
168
+ while True:
169
+ content_piece = response_queue.get()
170
+ chatbot[-1][1] = content_piece["text"]
171
+ yield chatbot
172
+ if content_piece["end_of_content"]:
173
+ break
174
+
175
+ def undo(chatbot):
176
+ if len(chatbot) > 0:
177
+ chatbot = chatbot[:-1]
178
+ return chatbot
179
+
180
+ def clear():
181
+ chatbot = []
182
+ msg = ""
183
+ return chatbot, msg
184
+
185
+ CSS ="""
186
+ .contain { display: flex; flex-direction: column; }
187
+ #component-0 { height: 100%; }
188
+ #chatbot { flex-grow: 1; overflow: auto;}
189
+ """
190
+ with gr.Blocks(css=CSS) as demo:
191
+ gr.Markdown("## OneLLM: One Framework to Align All Modalities with Language")
192
+ with gr.Row(equal_height=True):
193
+ with gr.Column(scale=1):
194
+ img_path = gr.Image(label='Image Input', type='filepath')
195
+ video_path = gr.Video(label='Video Input')
196
+ audio_path = gr.Audio(label='Audio Input', type='filepath', sources=['upload'])
197
+ modality = gr.Radio(choices=['image', 'audio', 'video'], value='image', interactive=True, label='Input Modalities')
198
+
199
+ with gr.Column(scale=2):
200
+ chatbot = gr.Chatbot(elem_id="chatbot")
201
+ msg = gr.Textbox()
202
+
203
+ with gr.Row():
204
+ submit_button = gr.Button("Submit", variant="primary")
205
+ undo_button = gr.Button("Undo")
206
+ clear_button = gr.ClearButton([chatbot, msg, img_path, audio_path, video_path, modality])
207
+ with gr.Row():
208
+ max_gen_len = gr.Slider(
209
+ minimum=1, maximum=args.model_max_seq_len // 2,
210
+ value=args.model_max_seq_len // 2, interactive=True,
211
+ label="Single-turn max response length",
212
+ )
213
+ gen_t = gr.Slider(
214
+ minimum=0, maximum=1, value=0.1, interactive=True,
215
+ label="Temperature",
216
+ )
217
+ top_p = gr.Slider(
218
+ minimum=0, maximum=1, value=0.75, interactive=True,
219
+ label="Top-p",
220
+ )
221
+ msg.submit(
222
+ show_user_input, [msg, chatbot], [msg, chatbot],
223
+ ).then(
224
+ stream_model_output, [img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot,
225
+ )
226
+ submit_button.click(
227
+ show_user_input, [msg, chatbot], [msg, chatbot],
228
+ ).then(
229
+ stream_model_output, [img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot,
230
+ )
231
+ undo_button.click(undo, chatbot, chatbot)
232
+ # img_path.change(clear, [], [chatbot, msg])
233
+ barrier.wait()
234
+ demo.queue(api_open=True).launch(share=True, max_threads=1)
235
+
236
+
237
+ if __name__ == "__main__":
238
+ parser = argparse.ArgumentParser("Chat Demo")
239
+ group = parser.add_mutually_exclusive_group()
240
+ group.add_argument(
241
+ "--gpu_ids", type=int, nargs="+",
242
+ help="A list of space-separated gpu ids to run the model on. "
243
+ "The model will span across GPUs in tensor-parallel mode."
244
+ )
245
+ parser.add_argument(
246
+ "--tokenizer_path", type=str,
247
+ help="Path to the tokenizer.model file provided along with the LLaMA "
248
+ "model."
249
+ )
250
+ parser.add_argument(
251
+ "--llama_type", default="onellm", type=str, metavar="MODEL",
252
+ help="LLaMA model type."
253
+ )
254
+ parser.add_argument(
255
+ "--llama_config", type=str, required=True,
256
+ help="Path to the llama model config json."
257
+ )
258
+ parser.add_argument(
259
+ "--model_max_seq_len", type=int, default=2048,
260
+ help="Max sequence length accepted by the pretrained model."
261
+ )
262
+ parser.add_argument(
263
+ "--pretrained_path", type=str, required=True,
264
+ help="Path to the llama model checkpoints. A list of checkpoints is "
265
+ "supported and will be merged from left to right.")
266
+ parser.add_argument(
267
+ "--master_port", type=int, default=23862,
268
+ help="A port used by the PyTorch distributed module to initialize."
269
+ )
270
+ parser.add_argument(
271
+ "--master_addr", type=str, default="127.0.0.1",
272
+ help="An address used by the PyTorch distributed module to initialize."
273
+ )
274
+ parser.add_argument(
275
+ "--dtype", type=str, choices=["fp16", "bf16"], default="fp16",
276
+ help="The dtype used for model weights and inference."
277
+ )
278
+ args = parser.parse_args()
279
+
280
+ # using the default "fork" method messes up some imported libs (e.g.,
281
+ # pandas)
282
+ mp.set_start_method("spawn")
283
+
284
+ # setup the queues and start the model workers
285
+ request_queues = []
286
+ response_queue = mp.Queue()
287
+ worker_processes = []
288
+ barrier = mp.Barrier(len(args.gpu_ids) + 1)
289
+ for rank, gpu_id in enumerate(args.gpu_ids):
290
+ request_queue = mp.Queue()
291
+ rank_response_queue = response_queue if rank == 0 else None
292
+ process = mp.Process(
293
+ target=model_worker,
294
+ args=(rank, args, barrier, request_queue, rank_response_queue),
295
+ )
296
+ process.start()
297
+ worker_processes.append(process)
298
+ request_queues.append(request_queue)
299
+
300
+ gradio_worker(request_queues, response_queue, args, barrier)
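For reference, the demo above can be launched with a command along these lines (the checkpoint path is a placeholder): `python demos/multi_turn_mm.py --gpu_ids 0 --tokenizer_path config/llama2/tokenizer.model --llama_config config/llama2/7B.json --pretrained_path <path/to/checkpoint.pth>`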
lib/__pycache__/point_utils.cpython-310.pyc ADDED
Binary file (6.74 kB)
 
lib/point_utils.py ADDED
@@ -0,0 +1,191 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.autograd import Function
4
+ import pointnet2_cuda
5
+
6
+ class KNN(nn.Module):
7
+ def __init__(self, neighbors, transpose_mode=True):
8
+ super(KNN, self).__init__()
9
+ self.neighbors = neighbors
10
+
11
+ @torch.no_grad()
12
+ def forward(self, support, query):
13
+ """
14
+ Args:
15
+ support ([tensor]): [B, N, C]
16
+ query ([tensor]): [B, M, C]
17
+ Returns:
18
+ [int]: neighbor idx. [B, M, K]
19
+ """
20
+ dist = torch.cdist(support, query)
21
+ k_dist = dist.topk(k=self.neighbors, dim=1, largest=False)
22
+ return k_dist.values, k_dist.indices.transpose(1, 2).contiguous().int()
23
+
24
+
25
+ class GroupingOperation(Function):
26
+
27
+ @staticmethod
28
+ @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
29
+ def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
30
+ """
31
+ :param ctx:
32
+ :param features: (B, C, N) tensor of features to group
33
+ :param idx: (B, npoint, nsample) tensor containing the indices of features to group with
34
+ :return:
35
+ output: (B, C, npoint, nsample) tensor
36
+ """
37
+ assert features.is_contiguous()
38
+ assert idx.is_contiguous()
39
+
40
+ B, nfeatures, nsample = idx.size()
41
+ _, C, N = features.size()
42
+ output = torch.cuda.FloatTensor(B, C, nfeatures, nsample, device=features.device)
43
+
44
+ pointnet2_cuda.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output)
45
+
46
+ ctx.for_backwards = (idx, N)
47
+ return output
48
+
49
+ @staticmethod
50
+ def backward(ctx, grad_out: torch.Tensor):
51
+ """
52
+ :param ctx:
53
+ :param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward
54
+ :return:
55
+ grad_features: (B, C, N) gradient of the features
56
+ """
57
+ idx, N = ctx.for_backwards
58
+
59
+ B, C, npoint, nsample = grad_out.size()
60
+ grad_features = torch.zeros([B, C, N], dtype=torch.float, device=grad_out.device, requires_grad=True)
61
+ grad_out_data = grad_out.data.contiguous()
62
+ pointnet2_cuda.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data)
63
+ return grad_features, None
64
+
65
+ grouping_operation = GroupingOperation.apply
66
+
67
+
68
+ class KNNGroup(nn.Module):
69
+ def __init__(self, nsample: int,
70
+ relative_xyz=True,
71
+ normalize_dp=False,
72
+ return_only_idx=False,
73
+ **kwargs
74
+ ):
75
+ """[summary]
76
+
77
+ Args:
78
+ nsample (int): maximum number of features to gather in the ball
79
+ use_xyz (bool, optional): concate xyz. Defaults to True.
80
+ ret_grouped_xyz (bool, optional): [description]. Defaults to False.
81
+ normalize_dp (bool, optional): [description]. Defaults to False.
82
+ """
83
+ super().__init__()
84
+ self.nsample = nsample
85
+ self.knn = KNN(nsample, transpose_mode=True)
86
+ self.relative_xyz = relative_xyz
87
+ self.normalize_dp = normalize_dp
88
+ self.return_only_idx = return_only_idx
89
+
90
+ def forward(self, query_xyz: torch.Tensor, support_xyz: torch.Tensor, features: torch.Tensor = None):
91
+ """
92
+ :param query_xyz: (B, N, 3) xyz coordinates of the features
93
+ :param support_xyz: (B, npoint, 3) centroids
94
+ :param features: (B, C, N) descriptors of the features
95
+ :return:
96
+ new_features: (B, 3 + C, npoint, nsample)
97
+ """
98
+ _, idx = self.knn(support_xyz, query_xyz)
99
+ if self.return_only_idx:
100
+ return idx
101
+ idx = idx.int()
102
+ xyz_trans = support_xyz.transpose(1, 2).contiguous()
103
+ grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample)
104
+ if self.relative_xyz:
105
+ grouped_xyz -= query_xyz.transpose(1, 2).unsqueeze(-1) # relative position
106
+ if self.normalize_dp:
107
+ grouped_xyz /= torch.amax(torch.sqrt(torch.sum(grouped_xyz**2, dim=1)), dim=(1, 2)).view(-1, 1, 1, 1)
108
+ if features is not None:
109
+ grouped_features = grouping_operation(features, idx)
110
+ return grouped_xyz, grouped_features
111
+ else:
112
+ return grouped_xyz, None
113
+
114
+
115
+ class FurthestPointSampling(Function):
116
+ @staticmethod
117
+ def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
118
+ """
119
+ Uses iterative furthest point sampling to select a set of npoint features that have the largest
120
+ minimum distance
121
+ :param ctx:
122
+ :param xyz: (B, N, 3) where N > npoint
123
+ :param npoint: int, number of features in the sampled set
124
+ :return:
125
+ output: (B, npoint) tensor containing the set (idx)
126
+ """
127
+ assert xyz.is_contiguous()
128
+
129
+ B, N, _ = xyz.size()
130
+ # output = torch.cuda.IntTensor(B, npoint, device=xyz.device)
131
+ # temp = torch.cuda.FloatTensor(B, N, device=xyz.device).fill_(1e10)
132
+ output = torch.cuda.IntTensor(B, npoint)
133
+ temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
134
+
135
+ pointnet2_cuda.furthest_point_sampling_wrapper(
136
+ B, N, npoint, xyz, temp, output)
137
+ return output
138
+
139
+ @staticmethod
140
+ def backward(xyz, a=None):
141
+ return None, None
142
+
143
+ furthest_point_sample = FurthestPointSampling.apply
144
+
145
+
146
+ class PointPatchEmbed(nn.Module):
147
+
148
+ def __init__(self,
149
+ sample_ratio=0.0625,
150
+ sample_number=1024,
151
+ group_size=32,
152
+ in_channels=6,
153
+ channels=1024,
154
+ kernel_size=1,
155
+ stride=1,
156
+ normalize_dp=False,
157
+ relative_xyz=True,
158
+ ):
159
+ super().__init__()
160
+ self.sample_ratio = sample_ratio
161
+ self.sample_number = sample_number
162
+ self.group_size = group_size
163
+
164
+ self.sample_fn = furthest_point_sample
165
+ self.grouper = KNNGroup(self.group_size, relative_xyz=relative_xyz, normalize_dp=normalize_dp)
166
+
167
+ self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=kernel_size, stride=stride)
168
+
169
+
170
+ def forward(self, x):
171
+ # coordinates
172
+ p = x[:, :, 3:].contiguous()
173
+
174
+ B, N, _ = p.shape[:3]
175
+ # idx = self.sample_fn(p, int(N * self.sample_ratio)).long()
176
+ idx = self.sample_fn(p, self.sample_number).long()
177
+ center_p = torch.gather(p, 1, idx.unsqueeze(-1).expand(-1, -1, 3))
178
+ # query neighbors.
179
+ _, fj = self.grouper(center_p, p, x.permute(0, 2, 1).contiguous()) # [B, N, 6] -> [B, 6, N] -> [B, 6, 1024, 32]
180
+
181
+ # [B, 6, 1024] -> [B, channels, 1024, 1]
182
+ fj = self.conv1(fj).max(dim=-1, keepdim=True)[0]
183
+
184
+ return fj
185
+
186
+
187
+ if __name__ == '__main__':
188
+ model = PointPatchEmbed(channels=256).cuda()
189
+ input = torch.rand(4, 16384, 6).cuda()
190
+ ou = model(input)
191
+ import pdb;pdb.set_trace()
lib/pointnet2/pointnet2_modules.py ADDED
@@ -0,0 +1,160 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from . import pointnet2_utils
6
+ from . import pytorch_utils as pt_utils
7
+ from typing import List
8
+
9
+
10
+ class _PointnetSAModuleBase(nn.Module):
11
+
12
+ def __init__(self):
13
+ super().__init__()
14
+ self.npoint = None
15
+ self.groupers = None
16
+ self.mlps = None
17
+ self.pool_method = 'max_pool'
18
+
19
+ def forward(self, xyz: torch.Tensor, features: torch.Tensor = None, new_xyz=None) -> (torch.Tensor, torch.Tensor):
20
+ """
21
+ :param xyz: (B, N, 3) tensor of the xyz coordinates of the features
22
+ :param features: (B, N, C) tensor of the descriptors of the features
23
+ :param new_xyz:
24
+ :return:
25
+ new_xyz: (B, npoint, 3) tensor of the new features' xyz
26
+ new_features: (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors
27
+ """
28
+ new_features_list = []
29
+
30
+ xyz_flipped = xyz.transpose(1, 2).contiguous()
31
+ if new_xyz is None:
32
+ new_xyz = pointnet2_utils.gather_operation(
33
+ xyz_flipped,
34
+ pointnet2_utils.furthest_point_sample(xyz, self.npoint)
35
+ ).transpose(1, 2).contiguous() if self.npoint is not None else None
36
+
37
+ for i in range(len(self.groupers)):
38
+ new_features = self.groupers[i](xyz, new_xyz, features) # (B, C, npoint, nsample)
39
+
40
+ new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample)
41
+ if self.pool_method == 'max_pool':
42
+ new_features = F.max_pool2d(
43
+ new_features, kernel_size=[1, new_features.size(3)]
44
+ ) # (B, mlp[-1], npoint, 1)
45
+ elif self.pool_method == 'avg_pool':
46
+ new_features = F.avg_pool2d(
47
+ new_features, kernel_size=[1, new_features.size(3)]
48
+ ) # (B, mlp[-1], npoint, 1)
49
+ else:
50
+ raise NotImplementedError
51
+
52
+ new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint)
53
+ new_features_list.append(new_features)
54
+
55
+ return new_xyz, torch.cat(new_features_list, dim=1)
56
+
57
+
58
+ class PointnetSAModuleMSG(_PointnetSAModuleBase):
59
+ """Pointnet set abstraction layer with multiscale grouping"""
60
+
61
+ def __init__(self, *, npoint: int, radii: List[float], nsamples: List[int], mlps: List[List[int]], bn: bool = True,
62
+ use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
63
+ """
64
+ :param npoint: int
65
+ :param radii: list of float, list of radii to group with
66
+ :param nsamples: list of int, number of samples in each ball query
67
+ :param mlps: list of list of int, spec of the pointnet before the global pooling for each scale
68
+ :param bn: whether to use batchnorm
69
+ :param use_xyz:
70
+ :param pool_method: max_pool / avg_pool
71
+ :param instance_norm: whether to use instance_norm
72
+ """
73
+ super().__init__()
74
+
75
+ assert len(radii) == len(nsamples) == len(mlps)
76
+
77
+ self.npoint = npoint
78
+ self.groupers = nn.ModuleList()
79
+ self.mlps = nn.ModuleList()
80
+ for i in range(len(radii)):
81
+ radius = radii[i]
82
+ nsample = nsamples[i]
83
+ self.groupers.append(
84
+ pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
85
+ if npoint is not None else pointnet2_utils.GroupAll(use_xyz)
86
+ )
87
+ mlp_spec = mlps[i]
88
+ if use_xyz:
89
+ mlp_spec[0] += 3
90
+
91
+ self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn, instance_norm=instance_norm))
92
+ self.pool_method = pool_method
93
+
94
+
95
+ class PointnetSAModule(PointnetSAModuleMSG):
96
+ """Pointnet set abstraction layer"""
97
+
98
+ def __init__(self, *, mlp: List[int], npoint: int = None, radius: float = None, nsample: int = None,
99
+ bn: bool = True, use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
100
+ """
101
+ :param mlp: list of int, spec of the pointnet before the global max_pool
102
+ :param npoint: int, number of features
103
+ :param radius: float, radius of ball
104
+ :param nsample: int, number of samples in the ball query
105
+ :param bn: whether to use batchnorm
106
+ :param use_xyz:
107
+ :param pool_method: max_pool / avg_pool
108
+ :param instance_norm: whether to use instance_norm
109
+ """
110
+ super().__init__(
111
+ mlps=[mlp], npoint=npoint, radii=[radius], nsamples=[nsample], bn=bn, use_xyz=use_xyz,
112
+ pool_method=pool_method, instance_norm=instance_norm
113
+ )
114
+
115
+
116
+ class PointnetFPModule(nn.Module):
117
+ r"""Propigates the features of one set to another"""
118
+
119
+ def __init__(self, *, mlp: List[int], bn: bool = True):
120
+ """
121
+ :param mlp: list of int
122
+ :param bn: whether to use batchnorm
123
+ """
124
+ super().__init__()
125
+ self.mlp = pt_utils.SharedMLP(mlp, bn=bn)
126
+
127
+ def forward(
128
+ self, unknown: torch.Tensor, known: torch.Tensor, unknow_feats: torch.Tensor, known_feats: torch.Tensor
129
+ ) -> torch.Tensor:
130
+ """
131
+ :param unknown: (B, n, 3) tensor of the xyz positions of the unknown features
132
+ :param known: (B, m, 3) tensor of the xyz positions of the known features
133
+ :param unknow_feats: (B, C1, n) tensor of the features to be propagated to
134
+ :param known_feats: (B, C2, m) tensor of features to be propagated
135
+ :return:
136
+ new_features: (B, mlp[-1], n) tensor of the features of the unknown features
137
+ """
138
+ if known is not None:
139
+ dist, idx = pointnet2_utils.three_nn(unknown, known)
140
+ dist_recip = 1.0 / (dist + 1e-8)
141
+ norm = torch.sum(dist_recip, dim=2, keepdim=True)
142
+ weight = dist_recip / norm
143
+
144
+ interpolated_feats = pointnet2_utils.three_interpolate(known_feats, idx, weight)
145
+ else:
146
+ interpolated_feats = known_feats.expand(*known_feats.size()[0:2], unknown.size(1))
147
+
148
+ if unknow_feats is not None:
149
+ new_features = torch.cat([interpolated_feats, unknow_feats], dim=1) # (B, C2 + C1, n)
150
+ else:
151
+ new_features = interpolated_feats
152
+
153
+ new_features = new_features.unsqueeze(-1)
154
+ new_features = self.mlp(new_features)
155
+
156
+ return new_features.squeeze(-1)
157
+
158
+
159
+ if __name__ == "__main__":
160
+ pass
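A minimal shape sketch for the set-abstraction module above, assuming the pointnet2_cuda extension from lib/pointnet2/setup.py has been built and a CUDA device is available (all sizes are illustrative):

import torch
sa = PointnetSAModuleMSG(
    npoint=512, radii=[0.1, 0.2], nsamples=[16, 32],
    mlps=[[3, 32, 64], [3, 32, 64]], use_xyz=True,
).cuda()
xyz = torch.rand(2, 1024, 3).cuda()      # point coordinates
feats = torch.rand(2, 3, 1024).cuda()    # per-point descriptors
new_xyz, new_feats = sa(xyz, feats)      # (2, 512, 3), (2, 64 + 64, 512)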
lib/pointnet2/pointnet2_utils.py ADDED
@@ -0,0 +1,290 @@
1
+ import torch
2
+ from torch.autograd import Variable
3
+ from torch.autograd import Function
4
+ import torch.nn as nn
5
+ from typing import Tuple
6
+
7
+ import pointnet2_cuda as pointnet2
8
+
9
+
10
+ class FurthestPointSampling(Function):
11
+ @staticmethod
12
+ def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
13
+ """
14
+ Uses iterative furthest point sampling to select a set of npoint features that have the largest
15
+ minimum distance
16
+ :param ctx:
17
+ :param xyz: (B, N, 3) where N > npoint
18
+ :param npoint: int, number of features in the sampled set
19
+ :return:
20
+ output: (B, npoint) tensor containing the set
21
+ """
22
+ assert xyz.is_contiguous()
23
+
24
+ B, N, _ = xyz.size()
25
+ output = torch.cuda.IntTensor(B, npoint)
26
+ temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
27
+
28
+ pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
29
+ return output
30
+
31
+ @staticmethod
32
+ def backward(xyz, a=None):
33
+ return None, None
34
+
35
+
36
+ furthest_point_sample = FurthestPointSampling.apply
37
+
38
+
39
+ class GatherOperation(Function):
40
+
41
+ @staticmethod
42
+ def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
43
+ """
44
+ :param ctx:
45
+ :param features: (B, C, N)
46
+ :param idx: (B, npoint) index tensor of the features to gather
47
+ :return:
48
+ output: (B, C, npoint)
49
+ """
50
+ assert features.is_contiguous()
51
+ assert idx.is_contiguous()
52
+
53
+ B, npoint = idx.size()
54
+ _, C, N = features.size()
55
+ output = torch.cuda.FloatTensor(B, C, npoint)
56
+
57
+ pointnet2.gather_points_wrapper(B, C, N, npoint, features, idx, output)
58
+
59
+ ctx.for_backwards = (idx, C, N)
60
+ return output
61
+
62
+ @staticmethod
63
+ def backward(ctx, grad_out):
64
+ idx, C, N = ctx.for_backwards
65
+ B, npoint = idx.size()
66
+
67
+ grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
68
+ grad_out_data = grad_out.data.contiguous()
69
+ pointnet2.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features.data)
70
+ return grad_features, None
71
+
72
+
73
+ gather_operation = GatherOperation.apply
74
+
75
+
76
+ class ThreeNN(Function):
77
+
78
+ @staticmethod
79
+ def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
80
+ """
81
+ Find the three nearest neighbors of unknown in known
82
+ :param ctx:
83
+ :param unknown: (B, N, 3)
84
+ :param known: (B, M, 3)
85
+ :return:
86
+ dist: (B, N, 3) l2 distance to the three nearest neighbors
87
+ idx: (B, N, 3) index of 3 nearest neighbors
88
+ """
89
+ assert unknown.is_contiguous()
90
+ assert known.is_contiguous()
91
+
92
+ B, N, _ = unknown.size()
93
+ m = known.size(1)
94
+ dist2 = torch.cuda.FloatTensor(B, N, 3)
95
+ idx = torch.cuda.IntTensor(B, N, 3)
96
+
97
+ pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx)
98
+ return torch.sqrt(dist2), idx
99
+
100
+ @staticmethod
101
+ def backward(ctx, a=None, b=None):
102
+ return None, None
103
+
104
+
105
+ three_nn = ThreeNN.apply
106
+
107
+
108
+ class ThreeInterpolate(Function):
109
+
110
+ @staticmethod
111
+ def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
112
+ """
113
+ Performs weighted linear interpolation on 3 features
114
+ :param ctx:
115
+ :param features: (B, C, M) Features descriptors to be interpolated from
116
+ :param idx: (B, n, 3) three nearest neighbors of the target features in features
117
+ :param weight: (B, n, 3) weights
118
+ :return:
119
+ output: (B, C, N) tensor of the interpolated features
120
+ """
121
+ assert features.is_contiguous()
122
+ assert idx.is_contiguous()
123
+ assert weight.is_contiguous()
124
+
125
+ B, c, m = features.size()
126
+ n = idx.size(1)
127
+ ctx.three_interpolate_for_backward = (idx, weight, m)
128
+ output = torch.cuda.FloatTensor(B, c, n)
129
+
130
+ pointnet2.three_interpolate_wrapper(B, c, m, n, features, idx, weight, output)
131
+ return output
132
+
133
+ @staticmethod
134
+ def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
135
+ """
136
+ :param ctx:
137
+ :param grad_out: (B, C, N) tensor with gradients of outputs
138
+ :return:
139
+ grad_features: (B, C, M) tensor with gradients of features
140
+ None:
141
+ None:
142
+ """
143
+ idx, weight, m = ctx.three_interpolate_for_backward
144
+ B, c, n = grad_out.size()
145
+
146
+ grad_features = Variable(torch.cuda.FloatTensor(B, c, m).zero_())
147
+ grad_out_data = grad_out.data.contiguous()
148
+
149
+ pointnet2.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data, idx, weight, grad_features.data)
150
+ return grad_features, None, None
151
+
152
+
153
+ three_interpolate = ThreeInterpolate.apply
154
+
155
+
156
+ class GroupingOperation(Function):
157
+
158
+ @staticmethod
159
+ def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
160
+ """
161
+ :param ctx:
162
+ :param features: (B, C, N) tensor of features to group
163
+ :param idx: (B, npoint, nsample) tensor containing the indices of features to group with
164
+ :return:
165
+ output: (B, C, npoint, nsample) tensor
166
+ """
167
+ assert features.is_contiguous()
168
+ assert idx.is_contiguous()
169
+
170
+ B, nfeatures, nsample = idx.size()
171
+ _, C, N = features.size()
172
+ output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)
173
+
174
+ pointnet2.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output)
175
+
176
+ ctx.for_backwards = (idx, N)
177
+ return output
178
+
179
+ @staticmethod
180
+ def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
181
+ """
182
+ :param ctx:
183
+ :param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward
184
+ :return:
185
+ grad_features: (B, C, N) gradient of the features
186
+ """
187
+ idx, N = ctx.for_backwards
188
+
189
+ B, C, npoint, nsample = grad_out.size()
190
+ grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
191
+
192
+ grad_out_data = grad_out.data.contiguous()
193
+ pointnet2.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data)
194
+ return grad_features, None
195
+
196
+
197
+ grouping_operation = GroupingOperation.apply
198
+
199
+
200
+ class BallQuery(Function):
201
+
202
+ @staticmethod
203
+ def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor:
204
+ """
205
+ :param ctx:
206
+ :param radius: float, radius of the balls
207
+ :param nsample: int, maximum number of features in the balls
208
+ :param xyz: (B, N, 3) xyz coordinates of the features
209
+ :param new_xyz: (B, npoint, 3) centers of the ball query
210
+ :return:
211
+ idx: (B, npoint, nsample) tensor with the indices of the features that form the query balls
212
+ """
213
+ assert new_xyz.is_contiguous()
214
+ assert xyz.is_contiguous()
215
+
216
+ B, N, _ = xyz.size()
217
+ npoint = new_xyz.size(1)
218
+ idx = torch.cuda.IntTensor(B, npoint, nsample).zero_()
219
+
220
+ pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx)
221
+ return idx
222
+
223
+ @staticmethod
224
+ def backward(ctx, a=None):
225
+ return None, None, None, None
226
+
227
+
228
+ ball_query = BallQuery.apply
229
+
230
+
231
+ class QueryAndGroup(nn.Module):
232
+ def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
233
+ """
234
+ :param radius: float, radius of ball
235
+ :param nsample: int, maximum number of features to gather in the ball
236
+ :param use_xyz:
237
+ """
238
+ super().__init__()
239
+ self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
240
+
241
+ def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None) -> Tuple[torch.Tensor]:
242
+ """
243
+ :param xyz: (B, N, 3) xyz coordinates of the features
244
+ :param new_xyz: (B, npoint, 3) centroids
245
+ :param features: (B, C, N) descriptors of the features
246
+ :return:
247
+ new_features: (B, 3 + C, npoint, nsample)
248
+ """
249
+ idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
250
+ xyz_trans = xyz.transpose(1, 2).contiguous()
251
+ grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample)
252
+ grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)
253
+
254
+ if features is not None:
255
+ grouped_features = grouping_operation(features, idx)
256
+ if self.use_xyz:
257
+ new_features = torch.cat([grouped_xyz, grouped_features], dim=1) # (B, C + 3, npoint, nsample)
258
+ else:
259
+ new_features = grouped_features
260
+ else:
261
+ assert self.use_xyz, "use_xyz must be True when no features are provided"
262
+ new_features = grouped_xyz
263
+
264
+ return new_features
265
+
266
+
267
+ class GroupAll(nn.Module):
268
+ def __init__(self, use_xyz: bool = True):
269
+ super().__init__()
270
+ self.use_xyz = use_xyz
271
+
272
+ def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None):
273
+ """
274
+ :param xyz: (B, N, 3) xyz coordinates of the features
275
+ :param new_xyz: ignored
276
+ :param features: (B, C, N) descriptors of the features
277
+ :return:
278
+ new_features: (B, C + 3, 1, N)
279
+ """
280
+ grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
281
+ if features is not None:
282
+ grouped_features = features.unsqueeze(2)
283
+ if self.use_xyz:
284
+ new_features = torch.cat([grouped_xyz, grouped_features], dim=1) # (B, 3 + C, 1, N)
285
+ else:
286
+ new_features = grouped_features
287
+ else:
288
+ new_features = grouped_xyz
289
+
290
+ return new_features
lib/pointnet2/pytorch_utils.py ADDED
@@ -0,0 +1,236 @@
1
+ import torch.nn as nn
2
+ from typing import List, Tuple
3
+
4
+
5
+ class SharedMLP(nn.Sequential):
6
+
7
+ def __init__(
8
+ self,
9
+ args: List[int],
10
+ *,
11
+ bn: bool = False,
12
+ activation=nn.ReLU(inplace=True),
13
+ preact: bool = False,
14
+ first: bool = False,
15
+ name: str = "",
16
+ instance_norm: bool = False,
17
+ ):
18
+ super().__init__()
19
+
20
+ for i in range(len(args) - 1):
21
+ self.add_module(
22
+ name + 'layer{}'.format(i),
23
+ Conv2d(
24
+ args[i],
25
+ args[i + 1],
26
+ bn=(not first or not preact or (i != 0)) and bn,
27
+ activation=activation
28
+ if (not first or not preact or (i != 0)) else None,
29
+ preact=preact,
30
+ instance_norm=instance_norm
31
+ )
32
+ )
33
+
34
+
35
+ class _ConvBase(nn.Sequential):
36
+
37
+ def __init__(
38
+ self,
39
+ in_size,
40
+ out_size,
41
+ kernel_size,
42
+ stride,
43
+ padding,
44
+ activation,
45
+ bn,
46
+ init,
47
+ conv=None,
48
+ batch_norm=None,
49
+ bias=True,
50
+ preact=False,
51
+ name="",
52
+ instance_norm=False,
53
+ instance_norm_func=None
54
+ ):
55
+ super().__init__()
56
+
57
+ bias = bias and (not bn)
58
+ conv_unit = conv(
59
+ in_size,
60
+ out_size,
61
+ kernel_size=kernel_size,
62
+ stride=stride,
63
+ padding=padding,
64
+ bias=bias
65
+ )
66
+ init(conv_unit.weight)
67
+ if bias:
68
+ nn.init.constant_(conv_unit.bias, 0)
69
+
70
+ if bn:
71
+ if not preact:
72
+ bn_unit = batch_norm(out_size)
73
+ else:
74
+ bn_unit = batch_norm(in_size)
75
+ if instance_norm:
76
+ if not preact:
77
+ in_unit = instance_norm_func(out_size, affine=False, track_running_stats=False)
78
+ else:
79
+ in_unit = instance_norm_func(in_size, affine=False, track_running_stats=False)
80
+
81
+ if preact:
82
+ if bn:
83
+ self.add_module(name + 'bn', bn_unit)
84
+
85
+ if activation is not None:
86
+ self.add_module(name + 'activation', activation)
87
+
88
+ if not bn and instance_norm:
89
+ self.add_module(name + 'in', in_unit)
90
+
91
+ self.add_module(name + 'conv', conv_unit)
92
+
93
+ if not preact:
94
+ if bn:
95
+ self.add_module(name + 'bn', bn_unit)
96
+
97
+ if activation is not None:
98
+ self.add_module(name + 'activation', activation)
99
+
100
+ if not bn and instance_norm:
101
+ self.add_module(name + 'in', in_unit)
102
+
103
+
104
+ class _BNBase(nn.Sequential):
105
+
106
+ def __init__(self, in_size, batch_norm=None, name=""):
107
+ super().__init__()
108
+ self.add_module(name + "bn", batch_norm(in_size))
109
+
110
+ nn.init.constant_(self[0].weight, 1.0)
111
+ nn.init.constant_(self[0].bias, 0)
112
+
113
+
114
+ class BatchNorm1d(_BNBase):
115
+
116
+ def __init__(self, in_size: int, *, name: str = ""):
117
+ super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name)
118
+
119
+
120
+ class BatchNorm2d(_BNBase):
121
+
122
+ def __init__(self, in_size: int, name: str = ""):
123
+ super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name)
124
+
125
+
126
+ class Conv1d(_ConvBase):
127
+
128
+ def __init__(
129
+ self,
130
+ in_size: int,
131
+ out_size: int,
132
+ *,
133
+ kernel_size: int = 1,
134
+ stride: int = 1,
135
+ padding: int = 0,
136
+ activation=nn.ReLU(inplace=True),
137
+ bn: bool = False,
138
+ init=nn.init.kaiming_normal_,
139
+ bias: bool = True,
140
+ preact: bool = False,
141
+ name: str = "",
142
+ instance_norm=False
143
+ ):
144
+ super().__init__(
145
+ in_size,
146
+ out_size,
147
+ kernel_size,
148
+ stride,
149
+ padding,
150
+ activation,
151
+ bn,
152
+ init,
153
+ conv=nn.Conv1d,
154
+ batch_norm=BatchNorm1d,
155
+ bias=bias,
156
+ preact=preact,
157
+ name=name,
158
+ instance_norm=instance_norm,
159
+ instance_norm_func=nn.InstanceNorm1d
160
+ )
161
+
162
+
163
+ class Conv2d(_ConvBase):
164
+
165
+ def __init__(
166
+ self,
167
+ in_size: int,
168
+ out_size: int,
169
+ *,
170
+ kernel_size: Tuple[int, int] = (1, 1),
171
+ stride: Tuple[int, int] = (1, 1),
172
+ padding: Tuple[int, int] = (0, 0),
173
+ activation=nn.ReLU(inplace=True),
174
+ bn: bool = False,
175
+ init=nn.init.kaiming_normal_,
176
+ bias: bool = True,
177
+ preact: bool = False,
178
+ name: str = "",
179
+ instance_norm=False
180
+ ):
181
+ super().__init__(
182
+ in_size,
183
+ out_size,
184
+ kernel_size,
185
+ stride,
186
+ padding,
187
+ activation,
188
+ bn,
189
+ init,
190
+ conv=nn.Conv2d,
191
+ batch_norm=BatchNorm2d,
192
+ bias=bias,
193
+ preact=preact,
194
+ name=name,
195
+ instance_norm=instance_norm,
196
+ instance_norm_func=nn.InstanceNorm2d
197
+ )
198
+
199
+
200
+ class FC(nn.Sequential):
201
+
202
+ def __init__(
203
+ self,
204
+ in_size: int,
205
+ out_size: int,
206
+ *,
207
+ activation=nn.ReLU(inplace=True),
208
+ bn: bool = False,
209
+ init=None,
210
+ preact: bool = False,
211
+ name: str = ""
212
+ ):
213
+ super().__init__()
214
+
215
+ fc = nn.Linear(in_size, out_size, bias=not bn)
216
+ if init is not None:
217
+ init(fc.weight)
218
+ if not bn:
219
+ nn.init.constant_(fc.bias, 0)
220
+
221
+ if preact:
222
+ if bn:
223
+ self.add_module(name + 'bn', BatchNorm1d(in_size))
224
+
225
+ if activation is not None:
226
+ self.add_module(name + 'activation', activation)
227
+
228
+ self.add_module(name + 'fc', fc)
229
+
230
+ if not preact:
231
+ if bn:
232
+ self.add_module(name + 'bn', BatchNorm1d(out_size))
233
+
234
+ if activation is not None:
235
+ self.add_module(name + 'activation', activation)
236
+
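A minimal sketch of how the helpers in this file compose. It assumes the repo root is importable; `pytorch_utils` itself only depends on `torch.nn`, so this runs on CPU.

```python
import torch
from lib.pointnet2.pytorch_utils import SharedMLP  # import path is an assumption

mlp = SharedMLP([6, 64, 128], bn=True)   # stacks 1x1 Conv2d -> BatchNorm2d -> ReLU stages
x = torch.rand(2, 6, 1, 2048)            # (B, C_in, 1, N) grouped point features
y = mlp(x)
print(y.shape)                           # torch.Size([2, 128, 1, 2048])
```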
lib/pointnet2/setup.py ADDED
@@ -0,0 +1,23 @@
1
+ from setuptools import setup
2
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3
+
4
+ setup(
5
+ name='pointnet2',
6
+ ext_modules=[
7
+ CUDAExtension('pointnet2_cuda', [
8
+ 'src/pointnet2_api.cpp',
9
+
10
+ 'src/ball_query.cpp',
11
+ 'src/ball_query_gpu.cu',
12
+ 'src/group_points.cpp',
13
+ 'src/group_points_gpu.cu',
14
+ 'src/interpolate.cpp',
15
+ 'src/interpolate_gpu.cu',
16
+ 'src/sampling.cpp',
17
+ 'src/sampling_gpu.cu',
18
+ ],
19
+ extra_compile_args={'cxx': ['-g'],
20
+ 'nvcc': ['-O2']})
21
+ ],
22
+ cmdclass={'build_ext': BuildExtension}
23
+ )
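The build step itself is not part of this commit, so as a hedged note: the extension is usually compiled in place before the Python wrappers can import it. The module name `pointnet2_cuda` comes from the `CUDAExtension` above; the bound function names appear in `pointnet2_api.cpp` further down.

```python
# Build and smoke-test the CUDA extension (assumes a CUDA toolkit compatible
# with the installed PyTorch; run from lib/pointnet2/):
#   python setup.py build_ext --inplace
# or, equivalently: pip install .
import torch
import pointnet2_cuda  # module name set by CUDAExtension('pointnet2_cuda', ...) above

print(torch.cuda.is_available())                      # the kernels need a CUDA device
print(hasattr(pointnet2_cuda, "ball_query_wrapper"))  # bound in pointnet2_api.cpp
```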
lib/pointnet2/src/ball_query.cpp ADDED
@@ -0,0 +1,24 @@
1
+ #include <torch/serialize/tensor.h>
2
+ #include <vector>
3
+ #include <ATen/cuda/CUDAContext.h>
4
+ #include <ATen/cuda/CUDAEvent.h>
5
+ #include <cuda.h>
6
+ #include <cuda_runtime_api.h>
7
+ #include "ball_query_gpu.h"
8
+
9
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
10
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
11
+ #define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
12
+
13
+ int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample,
14
+ at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor) {
15
+ CHECK_INPUT(new_xyz_tensor);
16
+ CHECK_INPUT(xyz_tensor);
17
+ const float *new_xyz = new_xyz_tensor.data<float>();
18
+ const float *xyz = xyz_tensor.data<float>();
19
+ int *idx = idx_tensor.data<int>();
20
+
21
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
22
+ ball_query_kernel_launcher_fast(b, n, m, radius, nsample, new_xyz, xyz, idx, stream);
23
+ return 1;
24
+ }
lib/pointnet2/src/ball_query_gpu.cu ADDED
@@ -0,0 +1,67 @@
1
+ #include <math.h>
2
+ #include <stdio.h>
3
+ #include <stdlib.h>
4
+
5
+ #include "ball_query_gpu.h"
6
+ #include "cuda_utils.h"
7
+
8
+
9
+ __global__ void ball_query_kernel_fast(int b, int n, int m, float radius, int nsample,
10
+ const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx) {
11
+ // new_xyz: (B, M, 3)
12
+ // xyz: (B, N, 3)
13
+ // output:
14
+ // idx: (B, M, nsample)
15
+ int bs_idx = blockIdx.y;
16
+ int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
17
+ if (bs_idx >= b || pt_idx >= m) return;
18
+
19
+ new_xyz += bs_idx * m * 3 + pt_idx * 3;
20
+ xyz += bs_idx * n * 3;
21
+ idx += bs_idx * m * nsample + pt_idx * nsample;
22
+
23
+ float radius2 = radius * radius;
24
+ float new_x = new_xyz[0];
25
+ float new_y = new_xyz[1];
26
+ float new_z = new_xyz[2];
27
+
28
+ int cnt = 0;
29
+ for (int k = 0; k < n; ++k) {
30
+ float x = xyz[k * 3 + 0];
31
+ float y = xyz[k * 3 + 1];
32
+ float z = xyz[k * 3 + 2];
33
+ float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z);
34
+ if (d2 < radius2){
35
+ if (cnt == 0){
36
+ for (int l = 0; l < nsample; ++l) {
37
+ idx[l] = k;
38
+ }
39
+ }
40
+ idx[cnt] = k;
41
+ ++cnt;
42
+ if (cnt >= nsample) break;
43
+ }
44
+ }
45
+ }
46
+
47
+
48
+ void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, \
49
+ const float *new_xyz, const float *xyz, int *idx, cudaStream_t stream) {
50
+ // new_xyz: (B, M, 3)
51
+ // xyz: (B, N, 3)
52
+ // output:
53
+ // idx: (B, M, nsample)
54
+
55
+ cudaError_t err;
56
+
57
+ dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row)
58
+ dim3 threads(THREADS_PER_BLOCK);
59
+
60
+ ball_query_kernel_fast<<<blocks, threads, 0, stream>>>(b, n, m, radius, nsample, new_xyz, xyz, idx);
61
+ // cudaDeviceSynchronize(); // for using printf in kernel function
62
+ err = cudaGetLastError();
63
+ if (cudaSuccess != err) {
64
+ fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
65
+ exit(-1);
66
+ }
67
+ }
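A hedged sketch of calling the compiled wrapper directly (the binding name comes from `pointnet2_api.cpp` below; it requires a CUDA device and the built extension). It illustrates the padding behaviour implemented in the kernel above.

```python
import torch
import pointnet2_cuda  # built from the sources in this directory

B, N, M, nsample, radius = 2, 4096, 512, 32, 0.2
xyz = torch.rand(B, N, 3).cuda()                            # (B, N, 3) all points
new_xyz = xyz[:, :M, :].contiguous()                        # (B, M, 3) query centres
idx = torch.zeros(B, M, nsample, dtype=torch.int32).cuda()  # output buffer

pointnet2_cuda.ball_query_wrapper(B, N, M, radius, nsample, new_xyz, xyz, idx)
# idx[b, m] now holds up to `nsample` neighbour indices within `radius` of
# new_xyz[b, m]; unused slots are padded with the first neighbour found.
```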
lib/pointnet2/src/ball_query_gpu.h ADDED
@@ -0,0 +1,15 @@
1
+ #ifndef _BALL_QUERY_GPU_H
2
+ #define _BALL_QUERY_GPU_H
3
+
4
+ #include <torch/serialize/tensor.h>
5
+ #include <vector>
6
+ #include <cuda.h>
7
+ #include <cuda_runtime_api.h>
8
+
9
+ int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample,
10
+ at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor);
11
+
12
+ void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample,
13
+ const float *xyz, const float *new_xyz, int *idx, cudaStream_t stream);
14
+
15
+ #endif
lib/pointnet2/src/cuda_utils.h ADDED
@@ -0,0 +1,15 @@
1
+ #ifndef _CUDA_UTILS_H
2
+ #define _CUDA_UTILS_H
3
+
4
+ #include <cmath>
5
+
6
+ #define TOTAL_THREADS 1024
7
+ #define THREADS_PER_BLOCK 256
8
+ #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
9
+
10
+ inline int opt_n_threads(int work_size) {
11
+ const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
12
+
13
+ return max(min(1 << pow_2, TOTAL_THREADS), 1);
14
+ }
15
+ #endif
lib/pointnet2/src/group_points.cpp ADDED
@@ -0,0 +1,34 @@
1
+ #include <torch/serialize/tensor.h>
2
+ #include <cuda.h>
3
+ #include <cuda_runtime_api.h>
4
+ #include <vector>
5
+ #include "group_points_gpu.h"
6
+ #include <ATen/cuda/CUDAContext.h>
7
+ #include <ATen/cuda/CUDAEvent.h>
8
+
9
+
10
+
11
+ int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample,
12
+ at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) {
13
+
14
+ float *grad_points = grad_points_tensor.data<float>();
15
+ const int *idx = idx_tensor.data<int>();
16
+ const float *grad_out = grad_out_tensor.data<float>();
17
+
18
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
19
+ group_points_grad_kernel_launcher_fast(b, c, n, npoints, nsample, grad_out, idx, grad_points, stream);
20
+ return 1;
21
+ }
22
+
23
+
24
+ int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample,
25
+ at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor) {
26
+
27
+ const float *points = points_tensor.data<float>();
28
+ const int *idx = idx_tensor.data<int>();
29
+ float *out = out_tensor.data<float>();
30
+
31
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
32
+ group_points_kernel_launcher_fast(b, c, n, npoints, nsample, points, idx, out, stream);
33
+ return 1;
34
+ }
lib/pointnet2/src/group_points_gpu.cu ADDED
@@ -0,0 +1,86 @@
1
+ #include <stdio.h>
2
+ #include <stdlib.h>
3
+
4
+ #include "cuda_utils.h"
5
+ #include "group_points_gpu.h"
6
+
7
+
8
+ __global__ void group_points_grad_kernel_fast(int b, int c, int n, int npoints, int nsample,
9
+ const float *__restrict__ grad_out, const int *__restrict__ idx, float *__restrict__ grad_points) {
10
+ // grad_out: (B, C, npoints, nsample)
11
+ // idx: (B, npoints, nsample)
12
+ // output:
13
+ // grad_points: (B, C, N)
14
+ int bs_idx = blockIdx.z;
15
+ int c_idx = blockIdx.y;
16
+ int index = blockIdx.x * blockDim.x + threadIdx.x;
17
+ int pt_idx = index / nsample;
18
+ if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;
19
+
20
+ int sample_idx = index % nsample;
21
+ grad_out += bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx;
22
+ idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx;
23
+
24
+ atomicAdd(grad_points + bs_idx * c * n + c_idx * n + idx[0] , grad_out[0]);
25
+ }
26
+
27
+ void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
28
+ const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) {
29
+ // grad_out: (B, C, npoints, nsample)
30
+ // idx: (B, npoints, nsample)
31
+ // output:
32
+ // grad_points: (B, C, N)
33
+ cudaError_t err;
34
+ dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
35
+ dim3 threads(THREADS_PER_BLOCK);
36
+
37
+ group_points_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, nsample, grad_out, idx, grad_points);
38
+
39
+ err = cudaGetLastError();
40
+ if (cudaSuccess != err) {
41
+ fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
42
+ exit(-1);
43
+ }
44
+ }
45
+
46
+
47
+ __global__ void group_points_kernel_fast(int b, int c, int n, int npoints, int nsample,
48
+ const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) {
49
+ // points: (B, C, N)
50
+ // idx: (B, npoints, nsample)
51
+ // output:
52
+ // out: (B, C, npoints, nsample)
53
+ int bs_idx = blockIdx.z;
54
+ int c_idx = blockIdx.y;
55
+ int index = blockIdx.x * blockDim.x + threadIdx.x;
56
+ int pt_idx = index / nsample;
57
+ if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;
58
+
59
+ int sample_idx = index % nsample;
60
+
61
+ idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx;
62
+ int in_idx = bs_idx * c * n + c_idx * n + idx[0];
63
+ int out_idx = bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx;
64
+
65
+ out[out_idx] = points[in_idx];
66
+ }
67
+
68
+
69
+ void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
70
+ const float *points, const int *idx, float *out, cudaStream_t stream) {
71
+ // points: (B, C, N)
72
+ // idx: (B, npoints, nsample)
73
+ // output:
74
+ // out: (B, C, npoints, nsample)
75
+ cudaError_t err;
76
+ dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
77
+ dim3 threads(THREADS_PER_BLOCK);
78
+
79
+ group_points_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, nsample, points, idx, out);
80
+ // cudaDeviceSynchronize(); // for using printf in kernel function
81
+ err = cudaGetLastError();
82
+ if (cudaSuccess != err) {
83
+ fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
84
+ exit(-1);
85
+ }
86
+ }
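For reference, the gather implemented by these kernels can be written in a few lines of plain PyTorch. This is a readability aid only, not the code path the repo uses.

```python
import torch

def group_points_reference(points, idx):
    """CPU reference of the kernel above: out[b, c, p, s] = points[b, c, idx[b, p, s]]."""
    B, C, N = points.shape
    _, npoints, nsample = idx.shape
    flat = idx.reshape(B, 1, npoints * nsample).expand(-1, C, -1)
    return torch.gather(points, 2, flat).reshape(B, C, npoints, nsample)

# points: (B, C, N) features, idx: (B, npoints, nsample) integer indices into N
out = group_points_reference(torch.rand(2, 8, 100), torch.randint(0, 100, (2, 16, 32)))
print(out.shape)  # torch.Size([2, 8, 16, 32])
```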
lib/pointnet2/src/group_points_gpu.h ADDED
@@ -0,0 +1,22 @@
1
+ #ifndef _GROUP_POINTS_GPU_H
2
+ #define _GROUP_POINTS_GPU_H
3
+
4
+ #include <torch/serialize/tensor.h>
5
+ #include <cuda.h>
6
+ #include <cuda_runtime_api.h>
7
+ #include <vector>
8
+
9
+
10
+ int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample,
11
+ at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor);
12
+
13
+ void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
14
+ const float *points, const int *idx, float *out, cudaStream_t stream);
15
+
16
+ int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample,
17
+ at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor);
18
+
19
+ void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
20
+ const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream);
21
+
22
+ #endif
lib/pointnet2/src/interpolate.cpp ADDED
@@ -0,0 +1,53 @@
1
+ #include <torch/serialize/tensor.h>
2
+ #include <vector>
3
+ #include <math.h>
4
+ #include <stdio.h>
5
+ #include <stdlib.h>
6
+ #include <cuda.h>
7
+ #include <cuda_runtime_api.h>
8
+ #include "interpolate_gpu.h"
9
+ #include <ATen/cuda/CUDAContext.h>
10
+ #include <ATen/cuda/CUDAEvent.h>
11
+
12
+
13
+ void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor,
14
+ at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor) {
15
+ const float *unknown = unknown_tensor.data<float>();
16
+ const float *known = known_tensor.data<float>();
17
+ float *dist2 = dist2_tensor.data<float>();
18
+ int *idx = idx_tensor.data<int>();
19
+
20
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
21
+ three_nn_kernel_launcher_fast(b, n, m, unknown, known, dist2, idx, stream);
22
+ }
23
+
24
+
25
+ void three_interpolate_wrapper_fast(int b, int c, int m, int n,
26
+ at::Tensor points_tensor,
27
+ at::Tensor idx_tensor,
28
+ at::Tensor weight_tensor,
29
+ at::Tensor out_tensor) {
30
+
31
+ const float *points = points_tensor.data<float>();
32
+ const float *weight = weight_tensor.data<float>();
33
+ float *out = out_tensor.data<float>();
34
+ const int *idx = idx_tensor.data<int>();
35
+
36
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
37
+ three_interpolate_kernel_launcher_fast(b, c, m, n, points, idx, weight, out, stream);
38
+ }
39
+
40
+ void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m,
41
+ at::Tensor grad_out_tensor,
42
+ at::Tensor idx_tensor,
43
+ at::Tensor weight_tensor,
44
+ at::Tensor grad_points_tensor) {
45
+
46
+ const float *grad_out = grad_out_tensor.data<float>();
47
+ const float *weight = weight_tensor.data<float>();
48
+ float *grad_points = grad_points_tensor.data<float>();
49
+ const int *idx = idx_tensor.data<int>();
50
+
51
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
52
+ three_interpolate_grad_kernel_launcher_fast(b, c, n, m, grad_out, idx, weight, grad_points, stream);
53
+ }
lib/pointnet2/src/interpolate_gpu.cu ADDED
@@ -0,0 +1,161 @@
1
+ #include <math.h>
2
+ #include <stdio.h>
3
+ #include <stdlib.h>
4
+
5
+ #include "cuda_utils.h"
6
+ #include "interpolate_gpu.h"
7
+
8
+
9
+ __global__ void three_nn_kernel_fast(int b, int n, int m, const float *__restrict__ unknown,
10
+ const float *__restrict__ known, float *__restrict__ dist2, int *__restrict__ idx) {
11
+ // unknown: (B, N, 3)
12
+ // known: (B, M, 3)
13
+ // output:
14
+ // dist2: (B, N, 3)
15
+ // idx: (B, N, 3)
16
+
17
+ int bs_idx = blockIdx.y;
18
+ int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
19
+ if (bs_idx >= b || pt_idx >= n) return;
20
+
21
+ unknown += bs_idx * n * 3 + pt_idx * 3;
22
+ known += bs_idx * m * 3;
23
+ dist2 += bs_idx * n * 3 + pt_idx * 3;
24
+ idx += bs_idx * n * 3 + pt_idx * 3;
25
+
26
+ float ux = unknown[0];
27
+ float uy = unknown[1];
28
+ float uz = unknown[2];
29
+
30
+ double best1 = 1e40, best2 = 1e40, best3 = 1e40;
31
+ int besti1 = 0, besti2 = 0, besti3 = 0;
32
+ for (int k = 0; k < m; ++k) {
33
+ float x = known[k * 3 + 0];
34
+ float y = known[k * 3 + 1];
35
+ float z = known[k * 3 + 2];
36
+ float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
37
+ if (d < best1) {
38
+ best3 = best2; besti3 = besti2;
39
+ best2 = best1; besti2 = besti1;
40
+ best1 = d; besti1 = k;
41
+ }
42
+ else if (d < best2) {
43
+ best3 = best2; besti3 = besti2;
44
+ best2 = d; besti2 = k;
45
+ }
46
+ else if (d < best3) {
47
+ best3 = d; besti3 = k;
48
+ }
49
+ }
50
+ dist2[0] = best1; dist2[1] = best2; dist2[2] = best3;
51
+ idx[0] = besti1; idx[1] = besti2; idx[2] = besti3;
52
+ }
53
+
54
+
55
+ void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown,
56
+ const float *known, float *dist2, int *idx, cudaStream_t stream) {
57
+ // unknown: (B, N, 3)
58
+ // known: (B, M, 3)
59
+ // output:
60
+ // dist2: (B, N, 3)
61
+ // idx: (B, N, 3)
62
+
63
+ cudaError_t err;
64
+ dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row)
65
+ dim3 threads(THREADS_PER_BLOCK);
66
+
67
+ three_nn_kernel_fast<<<blocks, threads, 0, stream>>>(b, n, m, unknown, known, dist2, idx);
68
+
69
+ err = cudaGetLastError();
70
+ if (cudaSuccess != err) {
71
+ fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
72
+ exit(-1);
73
+ }
74
+ }
75
+
76
+
77
+ __global__ void three_interpolate_kernel_fast(int b, int c, int m, int n, const float *__restrict__ points,
78
+ const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ out) {
79
+ // points: (B, C, M)
80
+ // idx: (B, N, 3)
81
+ // weight: (B, N, 3)
82
+ // output:
83
+ // out: (B, C, N)
84
+
85
+ int bs_idx = blockIdx.z;
86
+ int c_idx = blockIdx.y;
87
+ int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
88
+
89
+ if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;
90
+
91
+ weight += bs_idx * n * 3 + pt_idx * 3;
92
+ points += bs_idx * c * m + c_idx * m;
93
+ idx += bs_idx * n * 3 + pt_idx * 3;
94
+ out += bs_idx * c * n + c_idx * n;
95
+
96
+ out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] + weight[2] * points[idx[2]];
97
+ }
98
+
99
+ void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n,
100
+ const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream) {
101
+ // points: (B, C, M)
102
+ // idx: (B, N, 3)
103
+ // weight: (B, N, 3)
104
+ // output:
105
+ // out: (B, C, N)
106
+
107
+ cudaError_t err;
108
+ dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
109
+ dim3 threads(THREADS_PER_BLOCK);
110
+ three_interpolate_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, m, n, points, idx, weight, out);
111
+
112
+ err = cudaGetLastError();
113
+ if (cudaSuccess != err) {
114
+ fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
115
+ exit(-1);
116
+ }
117
+ }
118
+
119
+
120
+ __global__ void three_interpolate_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out,
121
+ const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ grad_points) {
122
+ // grad_out: (B, C, N)
123
+ // weight: (B, N, 3)
124
+ // output:
125
+ // grad_points: (B, C, M)
126
+
127
+ int bs_idx = blockIdx.z;
128
+ int c_idx = blockIdx.y;
129
+ int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
130
+
131
+ if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;
132
+
133
+ grad_out += bs_idx * c * n + c_idx * n + pt_idx;
134
+ weight += bs_idx * n * 3 + pt_idx * 3;
135
+ grad_points += bs_idx * c * m + c_idx * m;
136
+ idx += bs_idx * n * 3 + pt_idx * 3;
137
+
138
+
139
+ atomicAdd(grad_points + idx[0], grad_out[0] * weight[0]);
140
+ atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]);
141
+ atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]);
142
+ }
143
+
144
+ void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out,
145
+ const int *idx, const float *weight, float *grad_points, cudaStream_t stream) {
146
+ // grad_out: (B, C, N)
147
+ // weight: (B, N, 3)
148
+ // output:
149
+ // grad_points: (B, C, M)
150
+
151
+ cudaError_t err;
152
+ dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
153
+ dim3 threads(THREADS_PER_BLOCK);
154
+ three_interpolate_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, m, grad_out, idx, weight, grad_points);
155
+
156
+ err = cudaGetLastError();
157
+ if (cudaSuccess != err) {
158
+ fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
159
+ exit(-1);
160
+ }
161
+ }
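Likewise, the weighted three-point interpolation above reduces to a gather plus a weighted sum. A plain-PyTorch reference for comparison (again, not the code path the repo uses):

```python
import torch

def three_interpolate_reference(points, idx, weight):
    """CPU reference: out[b, c, i] = sum_k weight[b, i, k] * points[b, c, idx[b, i, k]]."""
    B, C, M = points.shape
    N = idx.shape[1]
    flat = idx.reshape(B, 1, N * 3).expand(-1, C, -1)
    gathered = torch.gather(points, 2, flat).reshape(B, C, N, 3)
    return (gathered * weight.unsqueeze(1)).sum(dim=-1)

points = torch.rand(2, 16, 64)                          # (B, C, M) features at known points
idx = torch.randint(0, 64, (2, 256, 3))                 # (B, N, 3) three nearest known points
weight = torch.softmax(torch.rand(2, 256, 3), dim=-1)   # (B, N, 3), rows sum to 1
print(three_interpolate_reference(points, idx, weight).shape)  # torch.Size([2, 16, 256])
```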
lib/pointnet2/src/interpolate_gpu.h ADDED
@@ -0,0 +1,30 @@
1
+ #ifndef _INTERPOLATE_GPU_H
2
+ #define _INTERPOLATE_GPU_H
3
+
4
+ #include <torch/serialize/tensor.h>
5
+ #include<vector>
6
+ #include <cuda.h>
7
+ #include <cuda_runtime_api.h>
8
+
9
+
10
+ void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor,
11
+ at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor);
12
+
13
+ void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown,
14
+ const float *known, float *dist2, int *idx, cudaStream_t stream);
15
+
16
+
17
+ void three_interpolate_wrapper_fast(int b, int c, int m, int n, at::Tensor points_tensor,
18
+ at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor out_tensor);
19
+
20
+ void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n,
21
+ const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream);
22
+
23
+
24
+ void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m, at::Tensor grad_out_tensor,
25
+ at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor grad_points_tensor);
26
+
27
+ void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out,
28
+ const int *idx, const float *weight, float *grad_points, cudaStream_t stream);
29
+
30
+ #endif
lib/pointnet2/src/pointnet2_api.cpp ADDED
@@ -0,0 +1,24 @@
1
+ #include <torch/serialize/tensor.h>
2
+ #include <torch/extension.h>
3
+
4
+ #include "ball_query_gpu.h"
5
+ #include "group_points_gpu.h"
6
+ #include "sampling_gpu.h"
7
+ #include "interpolate_gpu.h"
8
+
9
+
10
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
11
+ m.def("ball_query_wrapper", &ball_query_wrapper_fast, "ball_query_wrapper_fast");
12
+
13
+ m.def("group_points_wrapper", &group_points_wrapper_fast, "group_points_wrapper_fast");
14
+ m.def("group_points_grad_wrapper", &group_points_grad_wrapper_fast, "group_points_grad_wrapper_fast");
15
+
16
+ m.def("gather_points_wrapper", &gather_points_wrapper_fast, "gather_points_wrapper_fast");
17
+ m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper_fast, "gather_points_grad_wrapper_fast");
18
+
19
+ m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper");
20
+
21
+ m.def("three_nn_wrapper", &three_nn_wrapper_fast, "three_nn_wrapper_fast");
22
+ m.def("three_interpolate_wrapper", &three_interpolate_wrapper_fast, "three_interpolate_wrapper_fast");
23
+ m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_fast, "three_interpolate_grad_wrapper_fast");
24
+ }
lib/pointnet2/src/sampling.cpp ADDED
@@ -0,0 +1,45 @@
1
+ #include <torch/serialize/tensor.h>
2
+ #include <ATen/cuda/CUDAContext.h>
3
+ #include <vector>
4
+ #include <ATen/cuda/CUDAContext.h>
5
+ #include <ATen/cuda/CUDAEvent.h>
6
+ #include "sampling_gpu.h"
7
+
8
+
9
+
10
+ int gather_points_wrapper_fast(int b, int c, int n, int npoints,
11
+ at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor){
12
+ const float *points = points_tensor.data<float>();
13
+ const int *idx = idx_tensor.data<int>();
14
+ float *out = out_tensor.data<float>();
15
+
16
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
17
+ gather_points_kernel_launcher_fast(b, c, n, npoints, points, idx, out, stream);
18
+ return 1;
19
+ }
20
+
21
+
22
+ int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints,
23
+ at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) {
24
+
25
+ const float *grad_out = grad_out_tensor.data<float>();
26
+ const int *idx = idx_tensor.data<int>();
27
+ float *grad_points = grad_points_tensor.data<float>();
28
+
29
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
30
+ gather_points_grad_kernel_launcher_fast(b, c, n, npoints, grad_out, idx, grad_points, stream);
31
+ return 1;
32
+ }
33
+
34
+
35
+ int furthest_point_sampling_wrapper(int b, int n, int m,
36
+ at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) {
37
+
38
+ const float *points = points_tensor.data<float>();
39
+ float *temp = temp_tensor.data<float>();
40
+ int *idx = idx_tensor.data<int>();
41
+
42
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
43
+ furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx, stream);
44
+ return 1;
45
+ }
lib/pointnet2/src/sampling_gpu.cu ADDED
@@ -0,0 +1,253 @@
1
+ #include <stdio.h>
2
+ #include <stdlib.h>
3
+
4
+ #include "cuda_utils.h"
5
+ #include "sampling_gpu.h"
6
+
7
+
8
+ __global__ void gather_points_kernel_fast(int b, int c, int n, int m,
9
+ const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) {
10
+ // points: (B, C, N)
11
+ // idx: (B, M)
12
+ // output:
13
+ // out: (B, C, M)
14
+
15
+ int bs_idx = blockIdx.z;
16
+ int c_idx = blockIdx.y;
17
+ int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
18
+ if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;
19
+
20
+ out += bs_idx * c * m + c_idx * m + pt_idx;
21
+ idx += bs_idx * m + pt_idx;
22
+ points += bs_idx * c * n + c_idx * n;
23
+ out[0] = points[idx[0]];
24
+ }
25
+
26
+ void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints,
27
+ const float *points, const int *idx, float *out, cudaStream_t stream) {
28
+ // points: (B, C, N)
29
+ // idx: (B, npoints)
30
+ // output:
31
+ // out: (B, C, npoints)
32
+
33
+ cudaError_t err;
34
+ dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
35
+ dim3 threads(THREADS_PER_BLOCK);
36
+
37
+ gather_points_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, points, idx, out);
38
+
39
+ err = cudaGetLastError();
40
+ if (cudaSuccess != err) {
41
+ fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
42
+ exit(-1);
43
+ }
44
+ }
45
+
46
+ __global__ void gather_points_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out,
47
+ const int *__restrict__ idx, float *__restrict__ grad_points) {
48
+ // grad_out: (B, C, M)
49
+ // idx: (B, M)
50
+ // output:
51
+ // grad_points: (B, C, N)
52
+
53
+ int bs_idx = blockIdx.z;
54
+ int c_idx = blockIdx.y;
55
+ int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
56
+ if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;
57
+
58
+ grad_out += bs_idx * c * m + c_idx * m + pt_idx;
59
+ idx += bs_idx * m + pt_idx;
60
+ grad_points += bs_idx * c * n + c_idx * n;
61
+
62
+ atomicAdd(grad_points + idx[0], grad_out[0]);
63
+ }
64
+
65
+ void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints,
66
+ const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) {
67
+ // grad_out: (B, C, npoints)
68
+ // idx: (B, npoints)
69
+ // output:
70
+ // grad_points: (B, C, N)
71
+
72
+ cudaError_t err;
73
+ dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
74
+ dim3 threads(THREADS_PER_BLOCK);
75
+
76
+ gather_points_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, grad_out, idx, grad_points);
77
+
78
+ err = cudaGetLastError();
79
+ if (cudaSuccess != err) {
80
+ fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
81
+ exit(-1);
82
+ }
83
+ }
84
+
85
+
86
+ __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, int idx1, int idx2){
87
+ const float v1 = dists[idx1], v2 = dists[idx2];
88
+ const int i1 = dists_i[idx1], i2 = dists_i[idx2];
89
+ dists[idx1] = max(v1, v2);
90
+ dists_i[idx1] = v2 > v1 ? i2 : i1;
91
+ }
92
+
93
+ template <unsigned int block_size>
94
+ __global__ void furthest_point_sampling_kernel(int b, int n, int m,
95
+ const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) {
96
+ // dataset: (B, N, 3)
97
+ // tmp: (B, N)
98
+ // output:
99
+ // idx: (B, M)
100
+
101
+ if (m <= 0) return;
102
+ __shared__ float dists[block_size];
103
+ __shared__ int dists_i[block_size];
104
+
105
+ int batch_index = blockIdx.x;
106
+ dataset += batch_index * n * 3;
107
+ temp += batch_index * n;
108
+ idxs += batch_index * m;
109
+
110
+ int tid = threadIdx.x;
111
+ const int stride = block_size;
112
+
113
+ int old = 0;
114
+ if (threadIdx.x == 0)
115
+ idxs[0] = old;
116
+
117
+ __syncthreads();
118
+ for (int j = 1; j < m; j++) {
119
+ int besti = 0;
120
+ float best = -1;
121
+ float x1 = dataset[old * 3 + 0];
122
+ float y1 = dataset[old * 3 + 1];
123
+ float z1 = dataset[old * 3 + 2];
124
+ for (int k = tid; k < n; k += stride) {
125
+ float x2, y2, z2;
126
+ x2 = dataset[k * 3 + 0];
127
+ y2 = dataset[k * 3 + 1];
128
+ z2 = dataset[k * 3 + 2];
129
+ // float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
130
+ // if (mag <= 1e-3)
131
+ // continue;
132
+
133
+ float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
134
+ float d2 = min(d, temp[k]);
135
+ temp[k] = d2;
136
+ besti = d2 > best ? k : besti;
137
+ best = d2 > best ? d2 : best;
138
+ }
139
+ dists[tid] = best;
140
+ dists_i[tid] = besti;
141
+ __syncthreads();
142
+
143
+ if (block_size >= 1024) {
144
+ if (tid < 512) {
145
+ __update(dists, dists_i, tid, tid + 512);
146
+ }
147
+ __syncthreads();
148
+ }
149
+
150
+ if (block_size >= 512) {
151
+ if (tid < 256) {
152
+ __update(dists, dists_i, tid, tid + 256);
153
+ }
154
+ __syncthreads();
155
+ }
156
+ if (block_size >= 256) {
157
+ if (tid < 128) {
158
+ __update(dists, dists_i, tid, tid + 128);
159
+ }
160
+ __syncthreads();
161
+ }
162
+ if (block_size >= 128) {
163
+ if (tid < 64) {
164
+ __update(dists, dists_i, tid, tid + 64);
165
+ }
166
+ __syncthreads();
167
+ }
168
+ if (block_size >= 64) {
169
+ if (tid < 32) {
170
+ __update(dists, dists_i, tid, tid + 32);
171
+ }
172
+ __syncthreads();
173
+ }
174
+ if (block_size >= 32) {
175
+ if (tid < 16) {
176
+ __update(dists, dists_i, tid, tid + 16);
177
+ }
178
+ __syncthreads();
179
+ }
180
+ if (block_size >= 16) {
181
+ if (tid < 8) {
182
+ __update(dists, dists_i, tid, tid + 8);
183
+ }
184
+ __syncthreads();
185
+ }
186
+ if (block_size >= 8) {
187
+ if (tid < 4) {
188
+ __update(dists, dists_i, tid, tid + 4);
189
+ }
190
+ __syncthreads();
191
+ }
192
+ if (block_size >= 4) {
193
+ if (tid < 2) {
194
+ __update(dists, dists_i, tid, tid + 2);
195
+ }
196
+ __syncthreads();
197
+ }
198
+ if (block_size >= 2) {
199
+ if (tid < 1) {
200
+ __update(dists, dists_i, tid, tid + 1);
201
+ }
202
+ __syncthreads();
203
+ }
204
+
205
+ old = dists_i[0];
206
+ if (tid == 0)
207
+ idxs[j] = old;
208
+ }
209
+ }
210
+
211
+ void furthest_point_sampling_kernel_launcher(int b, int n, int m,
212
+ const float *dataset, float *temp, int *idxs, cudaStream_t stream) {
213
+ // dataset: (B, N, 3)
214
+ // tmp: (B, N)
215
+ // output:
216
+ // idx: (B, M)
217
+
218
+ cudaError_t err;
219
+ unsigned int n_threads = opt_n_threads(n);
220
+
221
+ switch (n_threads) {
222
+ case 1024:
223
+ furthest_point_sampling_kernel<1024><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
224
+ case 512:
225
+ furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
226
+ case 256:
227
+ furthest_point_sampling_kernel<256><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
228
+ case 128:
229
+ furthest_point_sampling_kernel<128><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
230
+ case 64:
231
+ furthest_point_sampling_kernel<64><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
232
+ case 32:
233
+ furthest_point_sampling_kernel<32><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
234
+ case 16:
235
+ furthest_point_sampling_kernel<16><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
236
+ case 8:
237
+ furthest_point_sampling_kernel<8><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
238
+ case 4:
239
+ furthest_point_sampling_kernel<4><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
240
+ case 2:
241
+ furthest_point_sampling_kernel<2><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
242
+ case 1:
243
+ furthest_point_sampling_kernel<1><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
244
+ default:
245
+ furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
246
+ }
247
+
248
+ err = cudaGetLastError();
249
+ if (cudaSuccess != err) {
250
+ fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
251
+ exit(-1);
252
+ }
253
+ }
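The furthest-point-sampling kernel above greedily grows the sample set starting from index 0, always adding the point farthest from the points picked so far. A small CPU sketch of the same recurrence, useful for sanity-checking on tiny inputs (O(N·M), so not a replacement for the kernel):

```python
import torch

def fps_reference(xyz, m):
    """CPU sketch of furthest point sampling; mirrors the kernel (starts from index 0)."""
    B, N, _ = xyz.shape
    idx = torch.zeros(B, m, dtype=torch.long)
    dist = torch.full((B, N), float("inf"))
    farthest = torch.zeros(B, dtype=torch.long)
    batch = torch.arange(B)
    for j in range(m):
        idx[:, j] = farthest
        d = ((xyz - xyz[batch, farthest].unsqueeze(1)) ** 2).sum(-1)  # (B, N)
        dist = torch.minimum(dist, d)          # distance to the nearest selected point
        farthest = dist.argmax(-1)             # next pick: the most distant point
    return idx

print(fps_reference(torch.rand(2, 1024, 3), 16).shape)  # torch.Size([2, 16])
```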
lib/pointnet2/src/sampling_gpu.h ADDED
@@ -0,0 +1,29 @@
1
+ #ifndef _SAMPLING_GPU_H
2
+ #define _SAMPLING_GPU_H
3
+
4
+ #include <torch/serialize/tensor.h>
5
+ #include <ATen/cuda/CUDAContext.h>
6
+ #include<vector>
7
+
8
+
9
+ int gather_points_wrapper_fast(int b, int c, int n, int npoints,
10
+ at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor);
11
+
12
+ void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints,
13
+ const float *points, const int *idx, float *out, cudaStream_t stream);
14
+
15
+
16
+ int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints,
17
+ at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor);
18
+
19
+ void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints,
20
+ const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream);
21
+
22
+
23
+ int furthest_point_sampling_wrapper(int b, int n, int m,
24
+ at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor);
25
+
26
+ void furthest_point_sampling_kernel_launcher(int b, int n, int m,
27
+ const float *dataset, float *temp, int *idxs, cudaStream_t stream);
28
+
29
+ #endif
model/LLM/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from . import onellm
model/LLM/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (189 Bytes).
model/LLM/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (187 Bytes).
model/LLM/__pycache__/onellm.cpython-310.pyc ADDED
Binary file (14 kB).
model/LLM/__pycache__/onellm.cpython-39.pyc ADDED
Binary file (13.9 kB).
model/LLM/onellm.py ADDED
@@ -0,0 +1,495 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # This software may be used and distributed according to the terms of the GNU General Public License version 3.
3
+
4
+ from typing import Optional, Tuple
5
+ from dataclasses import dataclass
6
+ import math
7
+ import functools
8
+ import copy
9
+
10
+ import torch
11
+ from torch import nn
12
+ import torch.nn.functional as F
13
+
14
+ import fairscale.nn.model_parallel.initialize as fs_init
15
+ from fairscale.nn.model_parallel.layers import (
16
+ ParallelEmbedding,
17
+ RowParallelLinear,
18
+ ColumnParallelLinear,
19
+ )
20
+ from ..components import RMSNorm
21
+ from flash_attn import flash_attn_func
22
+
23
+ import open_clip
24
+
25
+
26
+ default_linear_init = nn.init.xavier_uniform_
27
+
28
+
29
+ @dataclass
30
+ class ModelArgs:
31
+ dim: int = 512
32
+ n_layers: int = 8
33
+ n_heads: int = 8
34
+ vocab_size: int = -1 # defined later by tokenizer
35
+ multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
36
+ norm_eps: float = 1e-5
37
+
38
+ max_batch_size: int = 32
39
+ max_seq_len: int = 2048
40
+
41
+
42
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
43
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)
44
+ [: (dim // 2)].float() / dim))
45
+ t = torch.arange(end, device=freqs.device) # type: ignore
46
+ freqs = torch.outer(t, freqs).float() # type: ignore
47
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
48
+ return freqs_cis
49
+
50
+
51
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
52
+ ndim = x.ndim
53
+ assert 0 <= 1 < ndim
54
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1])
55
+ shape = [d if i == 1 or i == ndim -
56
+ 1 else 1 for i, d in enumerate(x.shape)]
57
+ return freqs_cis.view(*shape)
58
+
59
+
60
+ def apply_rotary_emb(
61
+ xq: torch.Tensor,
62
+ xk: torch.Tensor,
63
+ freqs_cis: torch.Tensor,
64
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
65
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
66
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
67
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
68
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
69
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
70
+ return xq_out.type_as(xq), xk_out.type_as(xk)
71
+
72
+
73
+ class Attention(nn.Module):
74
+ def __init__(self, args: ModelArgs):
75
+ super().__init__()
76
+
77
+ self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()
78
+ self.head_dim = args.dim // args.n_heads
79
+
80
+ self.wq = ColumnParallelLinear(
81
+ args.dim,
82
+ args.n_heads * self.head_dim,
83
+ bias=False,
84
+ gather_output=False,
85
+ init_method=default_linear_init,
86
+ )
87
+ self.wk = ColumnParallelLinear(
88
+ args.dim,
89
+ args.n_heads * self.head_dim,
90
+ bias=False,
91
+ gather_output=False,
92
+ init_method=default_linear_init,
93
+ )
94
+ self.wv = ColumnParallelLinear(
95
+ args.dim,
96
+ args.n_heads * self.head_dim,
97
+ bias=False,
98
+ gather_output=False,
99
+ init_method=default_linear_init,
100
+ )
101
+ self.wo = RowParallelLinear(
102
+ args.n_heads * self.head_dim,
103
+ args.dim,
104
+ bias=False,
105
+ input_is_parallel=True,
106
+ init_method=default_linear_init,
107
+ )
108
+
109
+ self.flash = True
110
+ self.k_cache, self.v_cache = None, None
111
+
112
+ def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], prompt=None):
113
+ bsz, seqlen, _ = x.shape
114
+ xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
115
+
116
+ xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
117
+ xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
118
+ xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)
119
+
120
+ if freqs_cis is not None:
121
+ xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
122
+
123
+ if self.k_cache is None or self.v_cache is None:
124
+ keys, values = xk, xv
125
+ else:
126
+ self.k_cache = self.k_cache.to(xk)
127
+ self.v_cache = self.v_cache.to(xv)
128
+ self.k_cache[:bsz, start_pos: start_pos + seqlen, :, :] = xk
129
+ self.v_cache[:bsz, start_pos: start_pos + seqlen, :, :] = xv
130
+ keys = self.k_cache[:bsz, :start_pos + seqlen]
131
+ values = self.v_cache[:bsz, :start_pos + seqlen]
132
+
133
+ output = flash_attn_func(
134
+ xq, keys, values, dropout_p=0.0, causal=mask is not None)
135
+ output = output.contiguous().view(bsz, seqlen, -1)
136
+
137
+ return self.wo(output)
138
+
139
+ def allocate_kv_cache(self, max_batch_size: int, max_seq_len: int) -> None:
140
+ kv_cache_shape = (max_batch_size, max_seq_len,
141
+ self.n_local_heads, self.head_dim)
142
+ if self.k_cache is None or self.k_cache.size() != kv_cache_shape:
143
+ self.k_cache = torch.empty(kv_cache_shape)
144
+ if self.v_cache is None or self.v_cache.size() != kv_cache_shape:
145
+ self.v_cache = torch.empty(kv_cache_shape)
146
+
147
+ def destroy_kv_cache(self) -> None:
148
+ self.k_cache, self.v_cache = None, None
149
+
150
+
151
+ class FeedForward(nn.Module):
152
+ def __init__(
153
+ self,
154
+ dim: int,
155
+ hidden_dim: int,
156
+ multiple_of: int,
157
+ ):
158
+ super().__init__()
159
+ hidden_dim = int(2 * hidden_dim / 3)
160
+ hidden_dim = multiple_of * \
161
+ ((hidden_dim + multiple_of - 1) // multiple_of)
162
+
163
+ self.w1 = ColumnParallelLinear(
164
+ dim, hidden_dim, bias=False, gather_output=False, init_method=default_linear_init,
165
+ )
166
+ self.w2 = RowParallelLinear(
167
+ hidden_dim, dim, bias=False, input_is_parallel=True, init_method=default_linear_init
168
+ )
169
+ self.w3 = ColumnParallelLinear(
170
+ dim, hidden_dim, bias=False, gather_output=False, init_method=default_linear_init
171
+ )
172
+
173
+ def _silu_gating(self, x, y):
174
+ return F.silu(x) * y
175
+
176
+ def forward(self, x):
177
+ return self.w2(self._silu_gating(self.w1(x), self.w3(x)))
178
+
179
+
180
+ class TransformerBlock(nn.Module):
181
+ def __init__(self, layer_id: int, args: ModelArgs):
182
+ super().__init__()
183
+ self.n_heads = args.n_heads
184
+ self.dim = args.dim
185
+ self.head_dim = args.dim // args.n_heads
186
+ self.attention = Attention(args)
187
+ self.feed_forward = FeedForward(
188
+ dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of
189
+ )
190
+ self.layer_id = layer_id
191
+ self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
192
+ self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
193
+
194
+ def _forward_ffn(self, h):
195
+ return h + self.feed_forward(self.ffn_norm(h))
196
+
197
+ def _forward_attention(self, x, start_pos, freqs_cis, mask, prompt):
198
+ return x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask, prompt)
199
+
200
+ def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], prompt=None):
201
+ h = self._forward_attention(x, start_pos, freqs_cis, mask, prompt)
202
+ out = self._forward_ffn(h)
203
+ return out
204
+
205
+
206
+ class Mlp(nn.Module):
207
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
208
+ """
209
+
210
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU):
211
+ super().__init__()
212
+ out_features = out_features or in_features
213
+ hidden_features = hidden_features or in_features
214
+
215
+ self.fc1 = nn.Linear(in_features, hidden_features)
216
+ self.act = act_layer()
217
+ self.fc2 = nn.Linear(hidden_features, out_features)
218
+
219
+ def forward(self, x):
220
+ x = self.fc1(x)
221
+ x = self.act(x)
222
+ x = self.fc2(x)
223
+ return x
224
+
225
+
226
+ class Transformer(nn.Module):
227
+ def __init__(self, params: ModelArgs):
228
+ super().__init__()
229
+ self.params = params
230
+ self.vocab_size = params.vocab_size
231
+ self.n_layers = params.n_layers
232
+ self.tok_embeddings = ParallelEmbedding(
233
+ params.vocab_size, params.dim, init_method=nn.init.normal_,
234
+ )
235
+
236
+ self.layers = torch.nn.ModuleList()
237
+ for layer_id in range(params.n_layers):
238
+ self.layers.append(TransformerBlock(layer_id, params))
239
+
240
+ self.norm = RMSNorm(params.dim, eps=params.norm_eps)
241
+ self.output = ColumnParallelLinear(
242
+ params.dim, params.vocab_size, bias=False, init_method=default_linear_init,
243
+ )
244
+
245
+ self.freqs_cis = precompute_freqs_cis(
246
+ self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
247
+ )
248
+
249
+ # load clip
250
+ self.clip, _, _ = open_clip.create_model_and_transforms(
251
+ 'ViT-L-14', pretrained='openai')
252
+ for param in self.clip.parameters():
253
+ param.requires_grad = False
254
+ param.data = param.data.half()
255
+ self.clip.transformer = None
256
+
257
+ self.image_words = 30
258
+ self.cache_image_words = 0 # for inference
259
+
260
+ clip_width = self.clip.visual.conv1.out_channels
261
+ # create modal shared modules
262
+ self.resample_layers = nn.ModuleDict()
263
+ self.num_experts = 3
264
+ self.num_resample_layers = 8
265
+ for expert in range(self.num_experts):
266
+ expert = str(expert)
267
+ self.resample_layers[expert] = nn.ModuleList()
268
+ resampler_params = copy.deepcopy(params)
269
+ resampler_params.n_heads = 16
270
+ resampler_params.dim = clip_width
271
+ for layer_id in range(self.num_resample_layers):
272
+ self.resample_layers[expert].append(
273
+ TransformerBlock(layer_id, resampler_params))
274
+
275
+ self.conv1 = nn.ModuleDict()
276
+ self.positional_embedding = nn.ParameterDict()
277
+ self.resample_tokens = nn.ParameterDict()
278
+ self.clip_proj1 = nn.ModuleDict()
279
+ self.clip_proj2 = nn.ModuleDict()
280
+ self.routers = nn.ModuleDict()
281
+ self.start_tag = nn.ParameterDict()
282
+ self.end_tag = nn.ParameterDict()
283
+ # self.modals = ['image', 'audio', 'point', 'video', 'rgbd', 'rgbn', 'fmri', 'imu']
284
+ self.modals = ['image', 'audio', 'video', 'rgbd', 'rgbn', 'fmri', 'imu']
285
+ for modal in self.modals:
286
+ if modal in ['image', 'video', 'rgbd', 'rgbn']:
287
+ modal_tokens = 256 + 1
288
+ pass
289
+ elif modal == 'audio':
290
+ self.conv1[modal] = nn.Conv2d(
291
+ 1, clip_width, kernel_size=(16, 16), stride=(10, 10))
292
+ modal_tokens = 1212 + 1
293
+ self.positional_embedding[modal] = nn.Parameter(
294
+ torch.empty([modal_tokens, clip_width]))
295
+ nn.init.normal_(self.positional_embedding[modal], std=0.02)
296
+ elif modal == 'point':
297
+ from lib.point_utils import PointPatchEmbed
298
+ self.conv1[modal] = PointPatchEmbed(
299
+ in_channels=6, channels=clip_width)
300
+ modal_tokens = 1024 + 1
301
+ self.positional_embedding[modal] = nn.Parameter(
302
+ torch.empty([modal_tokens, clip_width]))
303
+ nn.init.normal_(self.positional_embedding[modal], std=0.02)
304
+ elif modal == 'fmri':
305
+ self.conv1[modal] = nn.Linear(15724, 8192)
306
+ self.positional_embedding[modal] = nn.Parameter(
307
+ torch.empty([8+1, clip_width]))
308
+ nn.init.normal_(self.positional_embedding[modal], std=0.02)
309
+ elif modal == 'imu':
310
+ self.conv1[modal] = nn.Conv1d(
311
+ in_channels=6, out_channels=clip_width, kernel_size=10, bias=False)
312
+ self.positional_embedding[modal] = nn.Parameter(
313
+ torch.empty([391+1, clip_width]))
314
+ nn.init.normal_(self.positional_embedding[modal], std=0.02)
315
+
316
+ self.routers[modal] = Mlp(
317
+ clip_width, clip_width * 4, self.num_experts)
318
+
319
+ self.resample_tokens[modal] = nn.Parameter(
320
+ torch.empty([1, 30, resampler_params.dim]))
321
+ nn.init.normal_(self.resample_tokens[modal], std=0.02)
322
+
323
+ self.clip_proj1[modal] = nn.Sequential(
324
+ nn.Linear(clip_width, resampler_params.dim),
325
+ nn.LayerNorm(resampler_params.dim))
326
+
327
+ self.clip_proj2[modal] = nn.Sequential(
328
+ nn.Linear(resampler_params.dim, params.dim),
329
+ nn.LayerNorm(params.dim))
330
+
331
+ self.start_tag[modal] = nn.Parameter(torch.rand(1, 1, params.dim))
332
+ self.end_tag[modal] = nn.Parameter(torch.rand(1, 1, params.dim))
333
+
334
+ # @torch.no_grad()
335
+
336
+ def clip_encode_image(self, x, modal='image'):
337
+ # shape = [*, width, grid ** 2]
338
+ x = x.reshape(x.shape[0], x.shape[1], -1)
339
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
340
+
341
+ x = torch.cat([self.clip.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1,
342
+ x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
343
+
344
+ # reuse the pretrained CLIP positional embedding; audio/point/fmri/imu use their own learned embeddings
345
+ pos_embedding = self.clip.visual.positional_embedding
346
+ if modal in ['audio', 'point', 'fmri', 'imu']:
347
+ pos_embedding = self.positional_embedding[modal]
348
+
349
+ x = x + pos_embedding.to(x.dtype)
350
+ x = self.clip.visual.ln_pre(x)
351
+
352
+ x = x.permute(1, 0, 2) # NLD -> LND
353
+ x = self.clip.visual.transformer(x)
354
+ x = x.permute(1, 0, 2) # LND -> NLD
355
+
356
+ # preserve all spatial tokens
357
+ x = self.clip.visual.ln_post(x[:, :, :])
358
+
359
+ # if self.clip.visual.proj is not None:
360
+ # x = x @ self.clip.visual.proj
361
+
362
+ return x
363
+
364
+ def encode_image(self, x, modal='image'):
365
+ bsz = x.size(0)
366
+ T = 1
367
+ if modal in ['image']:
368
+ # modified from CLIP
369
+ x = self.clip.visual.conv1(x) # shape = [*, width, grid, grid]
370
+ elif modal in ['audio', 'imu']:
371
+ x = self.conv1[modal](x)
372
+ elif modal == 'point':
373
+ # [B, 16384, 6] -> [B, 1024, 1024, 1]
374
+ x = self.conv1[modal](x.float()).to(x.dtype)
375
+ elif modal in ['video', 'rgbd', 'rgbn']:
376
+ # [B, 15, 3, 224, 224]
377
+ B, T = x.shape[:2]
378
+ bsz = B * T
379
+ x = x.reshape(bsz, *x.shape[2:])
380
+ x = self.clip.visual.conv1(x)
381
+ elif modal == 'fmri':
382
+ x = self.conv1[modal](x)
383
+ # [B, 1, 8192] -> [B, 1024, 8]
384
+ x = x.reshape(x.size(0), self.clip.visual.conv1.out_channels, -1)
385
+
386
+ image_feats = self.clip_encode_image(x, modal=modal)
387
+ # take mean on time dimension
388
+ # all inputs are reduced to [B, L, D]
389
+ bsz = int(bsz / T)
390
+ image_feats = image_feats.reshape(
391
+ bsz, T, *image_feats.shape[1:]).mean(dim=1)
392
+
393
+ image_feats = self.clip_proj1[modal](image_feats)
394
+ image_feats = torch.cat(
395
+ [self.resample_tokens[modal].repeat(bsz, 1, 1), image_feats], dim=1)
396
+
397
+ # compute soft routing weights over the experts
398
+ # [B, L, D]->[B, L, N]
399
+ routing_weights = self.routers[modal](image_feats).sigmoid()
400
+ routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
401
+
402
+ image_feats_experts = []
403
+ for expert_id in range(self.num_experts):
404
+ image_feats_expert = image_feats
405
+ for layer in self.resample_layers[str(expert_id)]:
406
+ image_feats_expert = layer(image_feats_expert, 0, None, None)
407
+
408
+ image_feats_expert = image_feats_expert[:, :self.resample_tokens[modal].size(1)]
409
+ routing_weight = routing_weights[:, :self.resample_tokens[modal].size(
410
+ 1), expert_id]
411
+ # [B, L, D] * [B, L, 1]
412
+ image_feats_expert = image_feats_expert * routing_weight[:, :, None]
413
+
414
+ image_feats_experts.append(image_feats_expert)
415
+
416
+ image_feats = sum(image_feats_experts)
417
+ image_feats = self.clip_proj2[modal](image_feats)
418
+
419
+ return image_feats
420
+
421
+ def forward(self, examples, image=None, modal='image'):
422
+ self._destroy_kv_cache() # training always disables kv cache
423
+ modal = modal[0]
424
+ _bsz, seqlen = examples.shape
425
+ h = self.tok_embeddings(examples)
426
+ self.freqs_cis = self.freqs_cis.to(h.device)
427
+
428
+ start_pos = 0
429
+ prefix_len = 0
430
+ if image is not None:
431
+ h_bos, h_caption = h[:, :1], h[:, 1:]
432
+ image_tokens = self.encode_image(image, modal)
433
+ h = torch.cat((h_bos, self.start_tag[modal].expand(
434
+ _bsz, -1, -1), image_tokens, self.end_tag[modal].expand(_bsz, -1, -1), h_caption), dim=1)
435
+ # prefix = bos + start_tag[modal] + image tokens + end_tag[modal]; the caption tokens follow for generation
436
+ prefix_len = image_tokens.shape[1] + 1 + 1
437
+ seqlen = h.shape[1]
438
+
439
+ freqs_cis = self.freqs_cis[start_pos:start_pos + seqlen]
440
+ mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=h.device)
441
+ mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
442
+ for layer in self.layers:
443
+ h = layer(h, start_pos, freqs_cis, mask)
444
+ h = self.norm(h)
445
+ output = self.output(h[:, prefix_len:, :])
446
+ return output
447
+
448
+ @torch.inference_mode()
449
+ def forward_inference(self, tokens: torch.Tensor, start_pos: int, image=None, modal='image'):
450
+ modal = modal[0] if isinstance(modal, list) else modal
451
+ _bsz, seqlen = tokens.shape
452
+ if start_pos == 0:
453
+ # kv cache will not re-allocate if size is unchanged
454
+ self._allocate_kv_cache(_bsz)
455
+ h = self.tok_embeddings(tokens)
456
+ self.freqs_cis = self.freqs_cis.to(h.device)
457
+
458
+ if image is not None:
459
+ h_bos, h_caption = h[:, :1], h[:, 1:]
460
+ image_tokens = self.encode_image(image, modal)
461
+ self.cache_image_words = image_tokens.shape[1]
462
+ h = torch.cat((h_bos, self.start_tag[modal].repeat(_bsz, 1, 1), image_tokens, self.end_tag[modal].repeat(_bsz, 1, 1), h_caption), dim=1)
463
+ seqlen = h.shape[1]
464
+ freqs_cis = self.freqs_cis[0: seqlen]
465
+ else:
466
+ if start_pos == 0:
467
+ self.cache_image_words = 0
468
+ freqs_cis = self.freqs_cis[0: seqlen]
469
+ else:
470
+ # if image was not None when start_pos=0,
471
+ # the offset should be added to start_pos within later forward_inference calls
472
+ start_pos = start_pos + self.cache_image_words
473
+ freqs_cis = self.freqs_cis[start_pos: start_pos + seqlen]
474
+
475
+ # freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
476
+
477
+ mask = None
478
+ if seqlen > 1:
479
+ mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
480
+ mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
481
+
482
+ for layer in self.layers:
483
+ h = layer(h, start_pos, freqs_cis, mask)
484
+ h = self.norm(h)
485
+ output = self.output(h[:, -1, :]) # only compute last logits
486
+ return output.float()
487
+
488
+ def _allocate_kv_cache(self, max_batch_size: int) -> None:
489
+ for layer in self.layers:
490
+ layer.attention.allocate_kv_cache(
491
+ max_batch_size, self.params.max_seq_len)
492
+
493
+ def _destroy_kv_cache(self) -> None:
494
+ for layer in self.layers:
495
+ layer.attention.destroy_kv_cache()
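A small shape check of the rotary-embedding helpers defined in this file. This is a sketch only: importing `model.LLM.onellm` also requires `fairscale`, `flash_attn` and `open_clip` to be installed.

```python
import torch
from model.LLM.onellm import precompute_freqs_cis, apply_rotary_emb

bsz, seqlen, n_heads, head_dim = 2, 16, 8, 64
freqs_cis = precompute_freqs_cis(head_dim, seqlen)   # (seqlen, head_dim // 2), complex64
xq = torch.rand(bsz, seqlen, n_heads, head_dim)
xk = torch.rand(bsz, seqlen, n_heads, head_dim)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
print(xq_rot.shape, xk_rot.shape)                    # shapes unchanged: (2, 16, 8, 64)
```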
model/__init__.py ADDED
File without changes
model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (154 Bytes).
model/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (152 Bytes).
model/__pycache__/components.cpython-39.pyc ADDED
Binary file (2.15 kB).
model/__pycache__/meta.cpython-310.pyc ADDED
Binary file (5.52 kB).
model/__pycache__/meta.cpython-39.pyc ADDED
Binary file (5.52 kB).