diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..03ff76df5665b3fa05b3be5a1699b1e0dd298d41
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,10 @@
+__pycache__
+*.pyc
+*.egg-info
+dist
+
+output
+output_dir
+*.pth
+*.log
+weights
\ No newline at end of file
diff --git a/README.md b/README.md
index e0405e1925d27b053e6581abc975e9c885d3335e..2c5b4b0ef87dfe708a2d0290369750f47f2c50be 100644
--- a/README.md
+++ b/README.md
@@ -1,15 +1,100 @@
---
-title: LLaMA Adapter V2
+title: OneLLM
emoji: 🚀
colorFrom: red
colorTo: indigo
sdk: gradio
-sdk_version: 3.23.0
+sdk_version: 4.7.1
app_file: app.py
pinned: false
---
-### LLaMA-Adapter
-The official demo for LLaMA-Adapter V2.
-Please refer to our [arXiv paper](https://arxiv.org/abs/2303.16199) and [github](https://github.com/ZrrSkywalker/LLaMA-Adapter) for more details.
+# OneLLM: One Framework to Align All Modalities with Language
+[[Project Page](https://onellm.csuhan.com)] [[Paper](#)] [[Web Demo](https://huggingface.co/spaces/csuhan/OneLLM)]
+
+Authors: [Jiaming Han](), [Kaixiong Gong](), [Yiyuan Zhang](), [Jiaqi Wang](), [Kaipeng Zhang](), [Dahua Lin](), [Yu Qiao](), [Peng Gao](), [Xiangyu Yue]().
+
+## News
+
+- **2023.12.01** Release model weights and inference code.
+
+## Contents
+
+- [Install](#install)
+- [Models](#models)
+- [Demo](#demo)
+
+
+
+
+
+### TODO
+
+- [ ] Data
+- [ ] Evaluation
+- [ ] Training
+
+### Install
+
+1. Clone the repo into a local folder.
+
+```bash
+git clone https://github.com/csuhan/OneLLM
+
+cd OneLLM
+```
+
+2. Install packages.
+
+```bash
+conda create -n onellm python=3.9 -y
+conda activate onellm
+
+pip install -r requirements.txt
+
+# install pointnet
+cd lib/pointnet2
+python setup.py install
+```
+
+3. Install Apex. (Optional)
+
+```bash
+git clone https://github.com/NVIDIA/apex
+cd apex
+pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
+```
+
+### Models
+
+We provide a preview model at: [csuhan/OneLLM-7B](https://huggingface.co/csuhan/OneLLM-7B).
+
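+The snippet below is a minimal sketch of fetching the checkpoint programmatically with `huggingface_hub` (the same call `app.py` in this repo uses); the printed path can then be passed as `--pretrained_path` to the local demo.
+
+```python
+from huggingface_hub import hf_hub_download
+
+# Download (and cache) the preview checkpoint from the Hub.
+ckpt_path = hf_hub_download(repo_id="csuhan/OneLLM-7B",
+                            filename="consolidated.00-of-01.pth")
+print(ckpt_path)
+```
+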
+### Demo
+
+**Hugging Face Demo:** [csuhan/OneLLM](https://huggingface.co/spaces/csuhan/OneLLM).
+
+**Local Demo:** Assuming you have downloaded the weights to ${WEIGHTS_DIR}, run the following command to start a Gradio demo locally.
+
+```bash
+python demos/multi_turn_mm.py --gpu_ids 0 --tokenizer_path config/llama2/tokenizer.model --llama_config config/llama2/7B.json --pretrained_path ${WEIGHTS_DIR}/consolidated.00-of-01.pth
+```
+
+
+
+
+
+## Citation
+
+```
+@article{han2023onellm,
+ title={OneLLM: One Framework to Align All Modalities with Language},
+ author={Han, Jiaming and Gong, Kaixiong and Zhang, Yiyuan and Wang, Jiaqi and Zhang, Kaipeng and Lin, Dahua and Qiao, Yu and Gao, Peng and Yue, Xiangyu},
+ journal={arXiv preprint arXiv:xxxx},
+ year={2023}
+}
+```
+
+## Acknowledgement
+
+[LLaMA](https://github.com/facebookresearch/llama), [LLaMA-Adapter](https://github.com/OpenGVLab/LLaMA-Adapter), [LLaMA2-Accessory](https://github.com/Alpha-VLLM/LLaMA2-Accessory), [Meta-Transformer](https://github.com/invictus717/MetaTransformer), [ChatBridge](https://github.com/joez17/ChatBridge)
diff --git a/app.py b/app.py
index 26c88289cdf2fe061461952331735abe6fa46172..a180cda697755716b352d8fa6456204db8aac801 100644
--- a/app.py
+++ b/app.py
@@ -1,277 +1,272 @@
-import json
-import os
-import glob
import sys
-import time
-from pathlib import Path
-from typing import Tuple
+import os
+
+import argparse
+import multiprocessing as mp
+import numpy as np
+from typing import List, Optional
-from huggingface_hub import hf_hub_download
-from PIL import Image
-import gradio as gr
import torch
-from fairscale.nn.model_parallel.initialize import initialize_model_parallel
-
-from llama import LLaMA, ModelArgs, Tokenizer, Transformer, VisionModel
-
-os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
-
-PROMPT_DICT = {
- "prompt_input": (
- "Below is an instruction that describes a task, paired with an input that provides further context. "
- "Write a response that appropriately completes the request.\n\n"
- "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
- ),
- "prompt_no_input": (
- "Below is an instruction that describes a task. "
- "Write a response that appropriately completes the request.\n\n"
- "### Instruction:\n{instruction}\n\n### Response:"
- ),
-}
-
-
-def setup_model_parallel() -> Tuple[int, int]:
- os.environ['RANK'] = '0'
- os.environ['WORLD_SIZE'] = '1'
- os.environ['MP'] = '1'
- os.environ['MASTER_ADDR'] = '127.0.0.1'
- os.environ['MASTER_PORT'] = '2223'
- local_rank = int(os.environ.get("LOCAL_RANK", -1))
- world_size = int(os.environ.get("WORLD_SIZE", -1))
-
- torch.distributed.init_process_group("nccl")
- initialize_model_parallel(world_size)
- torch.cuda.set_device(local_rank)
-
- # seed must be the same in all processes
- torch.manual_seed(1)
- return local_rank, world_size
-
-
-def load(
- ckpt0_path: str,
- ckpt1_path: str,
- param_path: str,
- tokenizer_path: str,
- instruct_adapter_path: str,
- caption_adapter_path: str,
- local_rank: int,
- world_size: int,
- max_seq_len: int,
- max_batch_size: int,
-) -> LLaMA:
- start_time = time.time()
- print("Loading")
- instruct_adapter_checkpoint = torch.load(
- instruct_adapter_path, map_location="cpu")
- caption_adapter_checkpoint = torch.load(
- caption_adapter_path, map_location="cpu")
- with open(param_path, "r") as f:
- params = json.loads(f.read())
-
- model_args: ModelArgs = ModelArgs(
- max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
- )
- model_args.adapter_layer = int(
- instruct_adapter_checkpoint['adapter_query.weight'].shape[0] / model_args.adapter_len)
- model_args.cap_adapter_layer = int(
- caption_adapter_checkpoint['cap_adapter_query.weight'].shape[0] / model_args.cap_adapter_len)
-
- tokenizer = Tokenizer(model_path=tokenizer_path)
- model_args.vocab_size = tokenizer.n_words
- torch.set_default_tensor_type(torch.cuda.HalfTensor)
- model = Transformer(model_args)
-
- # To reduce memory usuage
- ckpt0 = torch.load(ckpt0_path, map_location='cuda')
- model.load_state_dict(ckpt0, strict=False)
- del ckpt0
- torch.cuda.empty_cache()
-
- ckpt1 = torch.load(ckpt1_path, map_location='cuda')
- model.load_state_dict(ckpt1, strict=False)
- del ckpt1
- torch.cuda.empty_cache()
-
- vision_model = VisionModel(model_args)
-
- torch.set_default_tensor_type(torch.FloatTensor)
- model.load_state_dict(instruct_adapter_checkpoint, strict=False)
- model.load_state_dict(caption_adapter_checkpoint, strict=False)
- vision_model.load_state_dict(caption_adapter_checkpoint, strict=False)
-
- generator = LLaMA(model, tokenizer, vision_model)
- print(f"Loaded in {time.time() - start_time:.2f} seconds")
- return generator
-
-
-def instruct_generate(
- instruct: str,
- input: str = 'none',
- max_gen_len=512,
- temperature: float = 0.1,
- top_p: float = 0.75,
-):
- if input == 'none':
- prompt = PROMPT_DICT['prompt_no_input'].format_map(
- {'instruction': instruct, 'input': ''})
- else:
- prompt = PROMPT_DICT['prompt_input'].format_map(
- {'instruction': instruct, 'input': input})
-
- results = generator.generate(
- [prompt], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p
- )
- result = results[0].strip()
- print(result)
- return result
-
-
-def caption_generate(
- img: str,
- max_gen_len=512,
- temperature: float = 0.1,
- top_p: float = 0.75,
-):
- imgs = [Image.open(img).convert('RGB')]
- prompts = ["Generate caption of this image :",] * len(imgs)
-
- results = generator.generate(
- prompts, imgs=imgs, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p
- )
- result = results[0].strip()
- print(result)
- return result
-
-
-def download_llama_adapter(instruct_adapter_path, caption_adapter_path):
- if not os.path.exists(instruct_adapter_path):
- os.system(
- f"wget -q -O {instruct_adapter_path} https://github.com/ZrrSkywalker/LLaMA-Adapter/releases/download/v.1.0.0/llama_adapter_len10_layer30_release.pth")
-
- if not os.path.exists(caption_adapter_path):
- os.system(
- f"wget -q -O {caption_adapter_path} https://github.com/ZrrSkywalker/LLaMA-Adapter/releases/download/v.1.0.0/llama_adapter_len10_layer30_caption_vit_l.pth")
-
-
-# ckpt_path = "/data1/llma/7B/consolidated.00.pth"
-# param_path = "/data1/llma/7B/params.json"
-# tokenizer_path = "/data1/llma/tokenizer.model"
-ckpt0_path = hf_hub_download(
- repo_id="csuhan/llama_storage", filename="consolidated.00_part0.pth")
-ckpt1_path = hf_hub_download(
- repo_id="csuhan/llama_storage", filename="consolidated.00_part1.pth")
-param_path = hf_hub_download(
- repo_id="nyanko7/LLaMA-7B", filename="params.json")
-tokenizer_path = hf_hub_download(
- repo_id="nyanko7/LLaMA-7B", filename="tokenizer.model")
-instruct_adapter_path = "llama_adapter_len10_layer30_release.pth"
-caption_adapter_path = "llama_adapter_len10_layer30_caption_vit_l.pth"
-max_seq_len = 512
-max_batch_size = 1
-
-# download models
-# download_llama_adapter(instruct_adapter_path, caption_adapter_path)
-
-local_rank, world_size = setup_model_parallel()
-if local_rank > 0:
- sys.stdout = open(os.devnull, "w")
-
-generator = load(
- ckpt0_path, ckpt1_path, param_path, tokenizer_path, instruct_adapter_path, caption_adapter_path, local_rank, world_size, max_seq_len, max_batch_size
-)
-
-
-def create_instruct_demo():
- with gr.Blocks() as instruct_demo:
- with gr.Row():
- with gr.Column():
- instruction = gr.Textbox(lines=2, label="Instruction")
- input = gr.Textbox(
- lines=2, label="Context input", placeholder='none')
- max_len = gr.Slider(minimum=1, maximum=512,
- value=128, label="Max length")
- with gr.Accordion(label='Advanced options', open=False):
- temp = gr.Slider(minimum=0, maximum=1,
- value=0.1, label="Temperature")
- top_p = gr.Slider(minimum=0, maximum=1,
- value=0.75, label="Top p")
-
- run_botton = gr.Button("Run")
-
- with gr.Column():
- outputs = gr.Textbox(lines=10, label="Output")
-
- inputs = [instruction, input, max_len, temp, top_p]
-
- examples = [
- "Tell me about alpacas.",
- "Write a Python program that prints the first 10 Fibonacci numbers.",
- "Write a conversation between the sun and pluto.",
- "Write a theory to explain why cat never existed",
- ]
- examples = [
- [x, "none", 128, 0.1, 0.75]
- for x in examples]
-
- gr.Examples(
- examples=examples,
- inputs=inputs,
- outputs=outputs,
- fn=instruct_generate,
- cache_examples=os.getenv('SYSTEM') == 'spaces'
- )
- run_botton.click(fn=instruct_generate, inputs=inputs, outputs=outputs)
- return instruct_demo
+import torch.distributed as dist
+from fairscale.nn.model_parallel import initialize as fs_init
-def create_caption_demo():
- with gr.Blocks() as instruct_demo:
- with gr.Row():
- with gr.Column():
- img = gr.Image(label='Input', type='filepath')
- max_len = gr.Slider(minimum=1, maximum=512,
- value=64, label="Max length")
- with gr.Accordion(label='Advanced options', open=False):
- temp = gr.Slider(minimum=0, maximum=1,
- value=0.1, label="Temperature")
- top_p = gr.Slider(minimum=0, maximum=1,
- value=0.75, label="Top p")
-
- run_botton = gr.Button("Run")
-
- with gr.Column():
- outputs = gr.Textbox(lines=10, label="Output")
-
- inputs = [img, max_len, temp, top_p]
-
- examples = glob.glob("caption_demo/*.jpg")
- examples = [
- [x, 64, 0.1, 0.75]
- for x in examples]
-
- gr.Examples(
- examples=examples,
- inputs=inputs,
- outputs=outputs,
- fn=caption_generate,
- cache_examples=os.getenv('SYSTEM') == 'spaces'
- )
- run_botton.click(fn=caption_generate, inputs=inputs, outputs=outputs)
- return instruct_demo
+import gradio as gr
+from util.misc import setup_for_distributed
+from util.misc import default_tensor_type
+from model.meta import MetaModel
+from data.conversation_lib import conv_templates, SeparatorStyle
+from PIL import Image
+import torchvision.transforms as transforms
+from data.fintune_dataset import make_audio_features
+from data import video_utils
+from dataclasses import dataclass
+from huggingface_hub import hf_hub_download
+T_random_resized_crop = transforms.Compose([
+ transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333), interpolation=3,
+ antialias=None), # 3 is bicubic
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
+
+
+def load_audio(audio_path):
+ fbank = make_audio_features(audio_path, mel_bins=128)
+ fbank = fbank.transpose(0, 1)[None] #[1, 128, 1024]
+ return fbank
+
+def load_video(video_path):
+ video_feats = video_utils.load_and_transform_video_data(video_path, video_path, clip_duration=1, clips_per_video=5)
+ return video_feats[:, :, 0]
+
+
+def model_worker(
+ rank: int, args: argparse.Namespace, barrier: mp.Barrier,
+ request_queue: mp.Queue, response_queue: Optional[mp.Queue] = None,
+) -> None:
+ """
+ The worker function that manipulates the GPU to run the inference.
+ Exact n_gpu workers are started, with each one operating on a separate GPU.
+
+ Args:
+ rank (int): Distributed rank of the worker.
+ args (argparse.Namespace): All command line arguments.
+ barrier (multiprocessing.Barrier): A barrier used to delay the start
+ of Web UI to be after the start of the model.
+ """
+
+ world_size = len(args.gpu_ids)
+ gpu_id = args.gpu_ids[rank]
+ dist.init_process_group(
+ backend="nccl", rank=rank, world_size=world_size,
+ init_method=f"tcp://{args.master_addr}:{args.master_port}",
+ )
+ print(f"| distributed init on worker {rank}/{world_size}. "
+ f"using gpu: {gpu_id}")
+ fs_init.initialize_model_parallel(world_size)
+ torch.cuda.set_device(gpu_id)
-description = """
-# LLaMA-Adapter🚀
-The official demo for **LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention**.
-Please refer to our [arXiv paper](https://arxiv.org/abs/2303.16199) and [github](https://github.com/ZrrSkywalker/LLaMA-Adapter) for more details.
-"""
+ torch.manual_seed(1)
+ np.random.seed(1)
+
+ # set the print behavior.
+ setup_for_distributed(rank == 0)
+
+ target_dtype = {
+ "bf16": torch.bfloat16,
+ "fp16": torch.float16
+ }[args.dtype]
+ with default_tensor_type(dtype=target_dtype, device="cuda"):
+ model = MetaModel(args.llama_type, args.llama_config, tokenizer_path=args.tokenizer_path)
+ print("Loading pretrained weights ...")
+ checkpoint = torch.load(args.pretrained_path, map_location='cpu')
+ msg = model.load_state_dict(checkpoint, strict=False)
+ print("load result:\n", msg)
+ model.cuda()
+ model.eval()
+ print(f"Model = {str(model)}")
+
+ barrier.wait()
+
+ while True:
+ img_path, audio_path, video_path, chatbot, max_gen_len, temperature, top_p, modality = request_queue.get()
+ if 'image' in modality and img_path is not None:
+ image = Image.open(img_path).convert('RGB')
+ inputs = T_random_resized_crop(image)
+ elif 'video' in modality and video_path is not None:
+ inputs = load_video(video_path)
+ elif 'audio' in modality and audio_path is not None:
+ inputs = load_audio(audio_path)
+ else:
+ inputs = None
+
+ if inputs is not None:
+ inputs = inputs[None].cuda().to(target_dtype)
+
+ conv = conv_templates["v1"].copy()
+ for user, bot in chatbot:
+ conv.append_message(conv.roles[0], user)
+ conv.append_message(conv.roles[1], bot)
+
+ with torch.cuda.amp.autocast(dtype=target_dtype):
+ print(conv.get_prompt())
+ for stream_response in model.stream_generate(
+ conv.get_prompt(), inputs,
+ max_gen_len=max_gen_len, temperature=temperature, top_p=top_p,
+                modal=modality
+ ):
+ conv_sep = (
+ conv.sep
+ if conv.sep_style == SeparatorStyle.SINGLE
+ else conv.sep2
+ )
+ end_pos = stream_response["text"].find(conv_sep)
+ if end_pos != -1:
+ stream_response["text"] = (
+ stream_response['text'][:end_pos].rstrip() + "\n"
+ )
+ stream_response["end_of_content"] = True
+
+ # keep a few characters if not end_of_content to avoid sending
+ # part of conv_sep before all of it is generated.
+ if not stream_response["end_of_content"]:
+ if len(stream_response["text"]) < len(conv_sep):
+ continue
+ stream_response["text"] = (
+ stream_response["text"][:-len(conv_sep)]
+ )
+
+ if response_queue is not None:
+ response_queue.put(stream_response)
+
+ if stream_response["end_of_content"]:
+ break
+
+
+def gradio_worker(
+ request_queues: List[mp.Queue], response_queue: mp.Queue,
+ args: argparse.Namespace, barrier: mp.Barrier,
+) -> None:
+ """
+ The gradio worker is responsible for displaying the WebUI and relay the
+ requests to model workers. It should be launched only once.
+
+ Args:
+ request_queues (List[mp.Queue]): A list of request queues (one for
+ each model worker).
+ args (argparse.Namespace): All command line arguments.
+ barrier (multiprocessing.Barrier): A barrier used to delay the start
+ of Web UI to be after the start of the model.
+ """
+
+ def show_user_input(msg, chatbot):
+ return "", chatbot + [[msg, None]]
+
+ def stream_model_output(img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality):
+ for queue in request_queues:
+ queue.put((img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality))
+ while True:
+ content_piece = response_queue.get()
+ chatbot[-1][1] = content_piece["text"]
+ yield chatbot
+ if content_piece["end_of_content"]:
+ break
+
+ def undo(chatbot):
+ if len(chatbot) > 0:
+ chatbot = chatbot[:-1]
+ return chatbot
+
+ def clear():
+ chatbot = []
+ msg = ""
+ return chatbot, msg
+
+    CSS = """
+ .contain { display: flex; flex-direction: column; }
+ #component-0 { height: 100%; }
+ #chatbot { flex-grow: 1; overflow: auto;}
+ """
+ with gr.Blocks(css=CSS) as demo:
+ gr.Markdown("## OneLLM: One Framework to Align All Modalities with Language")
+ with gr.Row(equal_height=True):
+ with gr.Column(scale=1):
+ img_path = gr.Image(label='Image Input', type='filepath')
+ video_path = gr.Video(label='Video Input')
+ audio_path = gr.Audio(label='Audio Input', type='filepath', sources=['upload'])
+ modality = gr.Radio(choices=['image', 'audio', 'video'], value='image', interactive=True, label='Input Modalities')
+
+ with gr.Column(scale=2):
+ chatbot = gr.Chatbot(elem_id="chatbot")
+ msg = gr.Textbox()
-with gr.Blocks(css='style.css') as demo:
- gr.Markdown(description)
- with gr.TabItem("Instruction-Following"):
- create_instruct_demo()
- with gr.TabItem("Image Captioning"):
- create_caption_demo()
+ with gr.Row():
+ submit_button = gr.Button("Submit", variant="primary")
+ undo_button = gr.Button("Undo")
+ clear_button = gr.ClearButton([chatbot, msg, img_path, audio_path, video_path, modality])
+ with gr.Row():
+ max_gen_len = gr.Slider(
+ minimum=1, maximum=args.model_max_seq_len // 2,
+ value=args.model_max_seq_len // 2, interactive=True,
+ label="Single-turn max response length",
+ )
+ gen_t = gr.Slider(
+ minimum=0, maximum=1, value=0.1, interactive=True,
+ label="Temperature",
+ )
+ top_p = gr.Slider(
+ minimum=0, maximum=1, value=0.75, interactive=True,
+ label="Top-p",
+ )
+ msg.submit(
+ show_user_input, [msg, chatbot], [msg, chatbot],
+ ).then(
+ stream_model_output, [img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot,
+ )
+ submit_button.click(
+ show_user_input, [msg, chatbot], [msg, chatbot],
+ ).then(
+ stream_model_output, [img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot,
+ )
+ undo_button.click(undo, chatbot, chatbot)
+ # img_path.change(clear, [], [chatbot, msg])
+ barrier.wait()
+ demo.queue(api_open=True).launch(share=True, max_threads=1)
+
+
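+# Note: DemoConfig below hard-codes the same settings that the README passes to
+# demos/multi_turn_mm.py as CLI flags (gpu_ids, tokenizer_path, llama_config,
+# pretrained_path, ...); here the checkpoint is fetched from the Hub instead of
+# a local weights directory.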
+@dataclass
+class DemoConfig:
+ gpu_ids = [0]
+ tokenizer_path = "config/llama2/tokenizer.model"
+ llama_type = "onellm"
+ llama_config = "config/llama2/7B.json"
+ model_max_seq_len = 2048
+ # pretrained_path = "weights/7B_2048/consolidated.00-of-01.pth"
+ pretrained_path = hf_hub_download(repo_id="csuhan/OneLLM-7B", filename="consolidated.00-of-01.pth")
+ master_port = 23861
+ master_addr = "127.0.0.1"
+ dtype = "fp16"
+
+if __name__ == "__main__":
+ args = DemoConfig()
+ # using the default "fork" method messes up some imported libs (e.g.,
+ # pandas)
+ mp.set_start_method("spawn")
+
+ # setup the queues and start the model workers
+ request_queues = []
+ response_queue = mp.Queue()
+ worker_processes = []
+ barrier = mp.Barrier(len(args.gpu_ids) + 1)
+ for rank, gpu_id in enumerate(args.gpu_ids):
+ request_queue = mp.Queue()
+ rank_response_queue = response_queue if rank == 0 else None
+ process = mp.Process(
+ target=model_worker,
+ args=(rank, args, barrier, request_queue, rank_response_queue),
+ )
+ process.start()
+ worker_processes.append(process)
+ request_queues.append(request_queue)
-demo.queue(api_open=True, concurrency_count=1).launch()
+ gradio_worker(request_queues, response_queue, args, barrier)
diff --git a/config/llama2/7B.json b/config/llama2/7B.json
new file mode 100644
index 0000000000000000000000000000000000000000..6523f76675b50e9cf3a57d1fb135189abcffe1c7
--- /dev/null
+++ b/config/llama2/7B.json
@@ -0,0 +1 @@
+{"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-05, "vocab_size": -1}
diff --git a/config/llama2/tokenizer.model b/config/llama2/tokenizer.model
new file mode 100644
index 0000000000000000000000000000000000000000..6c00c742ce03c627d6cd5b795984876fa49fa899
--- /dev/null
+++ b/config/llama2/tokenizer.model
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
+size 499723
diff --git a/data/__pycache__/conversation_lib.cpython-310.pyc b/data/__pycache__/conversation_lib.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7104daf11059185efd723e40160c7debf191003b
Binary files /dev/null and b/data/__pycache__/conversation_lib.cpython-310.pyc differ
diff --git a/data/__pycache__/conversation_lib.cpython-39.pyc b/data/__pycache__/conversation_lib.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fdca3a64c8523a2b4439b3e0c894c64b2020f486
Binary files /dev/null and b/data/__pycache__/conversation_lib.cpython-39.pyc differ
diff --git a/data/__pycache__/fintune_dataset.cpython-310.pyc b/data/__pycache__/fintune_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2e4e08c37f7d6e25b19727d2847d34fc1fdd0c8e
Binary files /dev/null and b/data/__pycache__/fintune_dataset.cpython-310.pyc differ
diff --git a/data/__pycache__/fintune_dataset.cpython-39.pyc b/data/__pycache__/fintune_dataset.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..45989953b86dec0a2084a59127bc1028067b8640
Binary files /dev/null and b/data/__pycache__/fintune_dataset.cpython-39.pyc differ
diff --git a/data/__pycache__/imu_utils.cpython-310.pyc b/data/__pycache__/imu_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cae8e9e22e039ecc388eb3193589fa83b5c3847b
Binary files /dev/null and b/data/__pycache__/imu_utils.cpython-310.pyc differ
diff --git a/data/__pycache__/imu_utils.cpython-39.pyc b/data/__pycache__/imu_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..46ddf01f39b937cad76df5c6201b4f62b441694d
Binary files /dev/null and b/data/__pycache__/imu_utils.cpython-39.pyc differ
diff --git a/data/__pycache__/video_utils.cpython-310.pyc b/data/__pycache__/video_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4c007e0ea4f4c05fe867e0bc31077d4c6bc0fe79
Binary files /dev/null and b/data/__pycache__/video_utils.cpython-310.pyc differ
diff --git a/data/__pycache__/video_utils.cpython-39.pyc b/data/__pycache__/video_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..81b5c7d05231e85a9f4552385921740940514e39
Binary files /dev/null and b/data/__pycache__/video_utils.cpython-39.pyc differ
diff --git a/data/conversation_lib.py b/data/conversation_lib.py
new file mode 100644
index 0000000000000000000000000000000000000000..783fe0eb8f9dd425ec6c285e820f755d2e955a3b
--- /dev/null
+++ b/data/conversation_lib.py
@@ -0,0 +1,369 @@
+import dataclasses
+from enum import auto, Enum
+from typing import List, Tuple
+
+
+class SeparatorStyle(Enum):
+ """Different separator style."""
+ SINGLE = auto()
+ TWO = auto()
+ MPT = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+ """A class that keeps all conversation history."""
+ system: str
+ roles: List[str]
+ messages: List[List[str]]
+ offset: int
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+ sep: str = "###"
+ sep2: str = None
+ version: str = "Unknown"
+
+ skip_next: bool = False
+
+ def get_prompt(self):
+ if self.sep_style == SeparatorStyle.SINGLE:
+ ret = self.system + '\n\n' + self.sep
+ for role, message in self.messages:
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + ": " + message + '\n' + self.sep
+ else:
+ ret += role + ":"
+ return ret
+ elif self.sep_style == SeparatorStyle.TWO:
+ seps = [self.sep, self.sep2]
+ ret = self.system + seps[0]
+ for i, (role, message) in enumerate(self.messages):
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + ": " + message + seps[i % 2]
+ else:
+ ret += role + ":"
+ return ret
+ if self.sep_style == SeparatorStyle.MPT:
+ ret = self.system + self.sep
+ for role, message in self.messages:
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + message + self.sep
+ else:
+ ret += role
+ return ret
+ else:
+ raise ValueError(f"Invalid style: {self.sep_style}")
+
+ def append_message(self, role, message):
+ self.messages.append([role, message])
+
+ def get_images(self, return_pil=False):
+ images = []
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
+ if i % 2 == 0:
+ if type(msg) is tuple:
+ import base64
+ from io import BytesIO
+ from PIL import Image
+ msg, image, image_process_mode = msg
+ if image_process_mode == "Pad":
+ def expand2square(pil_img, background_color=(122, 116, 104)):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+
+ image = expand2square(image)
+ elif image_process_mode == "Crop":
+ pass
+ elif image_process_mode == "Resize":
+ image = image.resize((224, 224))
+ else:
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
+ max_hw, min_hw = max(image.size), min(image.size)
+ aspect_ratio = max_hw / min_hw
+ max_len, min_len = 800, 400
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+ longest_edge = int(shortest_edge * aspect_ratio)
+ W, H = image.size
+ if H > W:
+ H, W = longest_edge, shortest_edge
+ else:
+ H, W = shortest_edge, longest_edge
+ image = image.resize((W, H))
+ if return_pil:
+ images.append(image)
+ else:
+ buffered = BytesIO()
+ image.save(buffered, format="JPEG")
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+ images.append(img_b64_str)
+ return images
+
+ def to_gradio_chatbot(self):
+ ret = []
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
+ if i % 2 == 0:
+ if type(msg) is tuple:
+ import base64
+ from io import BytesIO
+ msg, image, image_process_mode = msg
+ max_hw, min_hw = max(image.size), min(image.size)
+ aspect_ratio = max_hw / min_hw
+ max_len, min_len = 800, 400
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+ longest_edge = int(shortest_edge * aspect_ratio)
+ W, H = image.size
+ if H > W:
+ H, W = longest_edge, shortest_edge
+ else:
+ H, W = shortest_edge, longest_edge
+ image = image.resize((W, H))
+ # image = image.resize((224, 224))
+ buffered = BytesIO()
+ image.save(buffered, format="JPEG")
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                    img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
+                    msg = msg.replace('<image>', img_str)
+ ret.append([msg, None])
+ else:
+ ret[-1][-1] = msg
+ return ret
+
+ def copy(self):
+ return Conversation(
+ system=self.system,
+ roles=self.roles,
+ messages=[[x, y] for x, y in self.messages],
+ offset=self.offset,
+ sep_style=self.sep_style,
+ sep=self.sep,
+ sep2=self.sep2)
+
+ def dict(self):
+ if len(self.get_images()) > 0:
+ return {
+ "system": self.system,
+ "roles": self.roles,
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
+ "offset": self.offset,
+ "sep": self.sep,
+ "sep2": self.sep2,
+ }
+ return {
+ "system": self.system,
+ "roles": self.roles,
+ "messages": self.messages,
+ "offset": self.offset,
+ "sep": self.sep,
+ "sep2": self.sep2,
+ }
+
+
+conv_v1 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("Human", "Assistant"),
+ messages=(
+ ("Human", "Give three tips for staying healthy."),
+ ("Assistant",
+ "Sure, here are three tips for staying healthy:\n"
+ "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
+ "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
+ "and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or "
+ "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening "
+ "activities at least two days per week.\n"
+ "2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, "
+ "vegetables, whole grains, lean proteins, and healthy fats can help support "
+ "your overall health. Try to limit your intake of processed and high-sugar foods, "
+ "and aim to drink plenty of water throughout the day.\n"
+ "3. Get enough sleep: Getting enough quality sleep is essential for your physical "
+ "and mental health. Adults should aim for seven to nine hours of sleep per night. "
+ "Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
+ "help improve the quality of your sleep.")
+ ),
+ offset=2,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+conv_v1_2 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("Human", "Assistant"),
+ messages=(),
+
+ # (
+ # ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
+ # ("Assistant",
+ # "Renewable energy sources are those that can be replenished naturally in a relatively "
+ # "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
+ # "Non-renewable energy sources, on the other hand, are finite and will eventually be "
+ # "depleted, such as coal, oil, and natural gas. Here are some key differences between "
+ # "renewable and non-renewable energy sources:\n"
+ # "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
+ # "energy sources are finite and will eventually run out.\n"
+ # "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
+ # "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
+ # "and other negative effects.\n"
+ # "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
+ # "have lower operational costs than non-renewable sources.\n"
+ # "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
+ # "locations than non-renewable sources.\n"
+ # "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
+ # "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
+ # "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
+ # "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
+ # )
+    offset=2,
+    sep_style=SeparatorStyle.SINGLE,
+    sep="###",
+)
+
+conv_vicuna_v1_1 = Conversation(
+ system="A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+ roles=("USER", "ASSISTANT"),
+ version="v1",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+ sep2="",
+)
+
+conv_mpt = Conversation(
+ system="""<|im_start|>system
+- You are a helpful language and vision assistant.
+- You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
+- You should follow the instructions carefully and explain your answers in detail.""",
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+ version="mpt",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.MPT,
+ sep="<|im_end|>",
+)
+
+conv_mpt_text = Conversation(
+ system="""<|im_start|>system
+- You are a helpful assistant chatbot trained by MosaicML.
+- You answer questions.
+- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
+- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""",
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+ version="mpt",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.MPT,
+ sep="<|im_end|>",
+)
+
+conv_bair_v1 = Conversation(
+ system="BEGINNING OF CONVERSATION:",
+ roles=("USER", "GPT"),
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+ sep2="",
+)
+
+simple_conv = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("Human", "Assistant"),
+ messages=(
+ ("Human", "Hi!"),
+ ("Assistant", "Hi there! How can I help you today?")
+ ),
+ offset=2,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+simple_conv_multimodal = Conversation(
+ system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
+ "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+ "Follow the instructions carefully and explain your answers in detail.",
+ roles=("Human", "Assistant"),
+ messages=(
+ ("Human", "Hi!"),
+ ("Assistant", "Hi there! How can I help you today?\n")
+ ),
+ offset=2,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+simple_conv_mpt_multimodal = Conversation(
+ system="""<|im_start|>system
+- You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab.
+- You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
+- You should follow the instructions carefully and explain your answers in detail.""",
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+ version="mpt",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.MPT,
+ sep="<|im_end|>",
+)
+
+simple_conv_legacy = Conversation(
+ system="You are LLaVA, a large language model trained by UW Madison WAIV Lab."
+ "You are designed to assist human with a variety of tasks using natural language."
+ "Follow the instructions carefully.",
+ roles=("Human", "Assistant"),
+ messages=(
+ ("Human", "Hi!\n\n### Response:"),
+ ("Assistant", "Hi there! How can I help you today?\n")
+ ),
+ offset=2,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+conv_llava_v1 = Conversation(
+ system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
+ "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+ "Follow the instructions carefully and explain your answers in detail.",
+ roles=("USER", "ASSISTANT"),
+ version="v1",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+ sep2="",
+)
+
+default_conversation = conv_v1_2
+conv_templates = {
+ "default": conv_v1_2,
+ "simple": simple_conv,
+ "simple_legacy": simple_conv_legacy,
+ "multimodal": simple_conv_multimodal,
+ "mpt_multimodal": simple_conv_mpt_multimodal,
+ "llava_v1": conv_llava_v1,
+
+ # fastchat
+ "v1": conv_v1_2,
+ "bair_v1": conv_bair_v1,
+ "vicuna_v1_1": conv_vicuna_v1_1,
+ "mpt": conv_mpt,
+ "mpt_text": conv_mpt_text,
+}
+
+if __name__ == "__main__":
+ print(default_conversation.get_prompt())
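+
+    # A minimal multi-turn sketch (mirroring how app.py builds prompts; the
+    # example messages are illustrative only):
+    conv = conv_templates["v1"].copy()
+    conv.append_message(conv.roles[0], "Describe this image.")
+    conv.append_message(conv.roles[1], None)
+    print(conv.get_prompt())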
diff --git a/data/fintune_dataset.py b/data/fintune_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f787139d702fbc46ef5ca8189ade8f89a9c7df0
--- /dev/null
+++ b/data/fintune_dataset.py
@@ -0,0 +1,449 @@
+import warnings
+
+import torch
+import yaml
+from torch.utils.data import Dataset
+from PIL import Image
+import json
+from model.tokenizer import Tokenizer
+import os
+import torchvision.transforms as transforms
+import random
+import torchvision.transforms.functional as F
+import torchaudio
+from . import conversation_lib
+
+import numpy as np
+from . import video_utils
+from .imu_utils import get_imu_frames
+
+
+IGNORE_INDEX = -100
+
+DEFAULT_IMAGE_TOKEN = "<image>"
+try:
+ from torchvision.transforms import InterpolationMode
+
+ BICUBIC = InterpolationMode.BICUBIC
+except ImportError:
+ BICUBIC = Image.BICUBIC
+
+T_random_resized_crop = transforms.Compose([
+ transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333), interpolation=BICUBIC,
+ antialias=None), # 3 is bicubic
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
+
+
+# image transform
+transform_img_train = transforms.Compose([
+ transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(
+ 0.75, 1.3333), interpolation=3, antialias=None), # 3 is bicubic
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
+
+
+class PairRandomResizedCrop(transforms.RandomResizedCrop):
+ def forward(self, imgs):
+ i, j, h, w = self.get_params(imgs[0], self.scale, self.ratio)
+ return [F.resized_crop(img, i, j, h, w, self.size, self.interpolation, antialias=self.antialias) for img in imgs]
+
+
+class PairToTensor(transforms.ToTensor):
+ def __call__(self, pics):
+ return [F.to_tensor(pic) for pic in pics]
+
+
+class PairNormalize(transforms.Normalize):
+ def forward(self, tensors):
+ return [F.normalize(tensor, self.mean, self.std, self.inplace) for tensor in tensors]
+
+
+transform_pairimg_train = transforms.Compose([
+ PairRandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(
+ 0.75, 1.3333), interpolation=3, antialias=None), # 3 is bicubic
+ PairToTensor(),
+ PairNormalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
+
+
+def pc_norm(pc):
+ """ pc: NxC, return NxC """
+ xyz = pc[:, :3]
+ other_feature = pc[:, 3:]
+
+ centroid = torch.mean(xyz, dim=0)
+ xyz = xyz - centroid
+ m = torch.max(torch.sqrt(torch.sum(xyz ** 2, dim=1)))
+ xyz = xyz / m
+
+ pc = torch.cat((xyz, other_feature), dim=1)
+ return pc
+
+
+def make_audio_features(wav_name, mel_bins=128, target_length=1024, aug=False):
+ waveform, sr = torchaudio.load(wav_name)
+ # assert sr == 16000, 'input audio sampling rate must be 16kHz'
+ if sr != 16000:
+ trans = torchaudio.transforms.Resample(sr, 16000)
+ waveform = trans(waveform)
+
+ waveform = waveform - waveform.mean()
+
+ fbank = torchaudio.compliance.kaldi.fbank(
+ waveform, htk_compat=True, sample_frequency=16000, use_energy=False,
+ window_type='hanning', num_mel_bins=mel_bins, dither=0.0, frame_shift=10)
+
+ n_frames = fbank.shape[0]
+
+ p = target_length - n_frames
+ if p > 0:
+ m = torch.nn.ZeroPad2d((0, 0, 0, p))
+ fbank = m(fbank)
+ elif p < 0:
+ fbank = fbank[0:target_length, :]
+
+ if aug:
+ freqm = torchaudio.transforms.FrequencyMasking(48)
+ timem = torchaudio.transforms.TimeMasking(192)
+ fbank = torch.transpose(fbank, 0, 1)
+ fbank = fbank.unsqueeze(0)
+ fbank = freqm(fbank)
+ fbank = timem(fbank)
+ fbank = fbank.squeeze(0)
+ fbank = torch.transpose(fbank, 0, 1)
+
+ fbank = (fbank - (-4.2677393)) / (4.5689974 * 2)
+ return fbank
+
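+# Shape note for make_audio_features above: the returned fbank is
+# [target_length, mel_bins] (1024 x 128 by default); load_audio below
+# transposes it to [1, 128, 1024] before it is passed to the model.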
+
+class ConversationGenerator:
+ def __init__(self, tokenizer):
+ self.tokenizer = tokenizer
+ self.header = f"{conversation_lib.default_conversation.system}\n\n"
+ self._probe_tokenizer_style()
+
+ def _probe_tokenizer_style(self):
+ """
+        Given a sentence, e.g. "My darling", some tokenizers will make the space a separate token,
+ while some others will merge the space into the next word, forming a token representing " darling".
+ Knowing which style the tokenizer takes is necessary for correct ground-truth label masking.
+
+ """
+ probe = "Probe am I"
+ sentence1 = self.tokenizer.encode(conversation_lib.default_conversation.roles[1] + ": " + probe,
+ bos=False, eos=False)
+ sentence2 = self.tokenizer.encode(probe,
+ bos=False, eos=False)
+ if sentence1[-len(sentence2):] == sentence2:
+ self.space_before_to_predict = False
+ else:
+ sentence3 = self.tokenizer.encode(" " + probe,
+ bos=False, eos=False)
+ assert sentence1[-len(sentence3):] == sentence3
+ self.space_before_to_predict = True
+
+ def add_speaker_and_signal(self, source, get_conversation=True):
+ """Add speaker and start/end signal on each round."""
+ BEGIN_SIGNAL = "### "
+ END_SIGNAL = "\n"
+ conversation = self.header
+
+ to_predict_list = []
+
+ for sentence in source:
+ from_str = sentence["from"]
+ if from_str.lower() in ["human"]:
+ from_str = conversation_lib.default_conversation.roles[0]
+ elif from_str.lower() in ["gpt", "assistant"]:
+ from_str = conversation_lib.default_conversation.roles[1]
+ else:
+ raise ValueError(f"unknown dialog role: {from_str.lower()}")
+
+ value = sentence["value"]
+ if DEFAULT_IMAGE_TOKEN in value:
+ value = value.replace(DEFAULT_IMAGE_TOKEN, '').strip()
+
+ sentence_value = BEGIN_SIGNAL + from_str + ": " + value + END_SIGNAL
+
+ if from_str == conversation_lib.default_conversation.roles[1]:
+ to_predict_value = value + END_SIGNAL + "###"
+ if self.space_before_to_predict:
+ to_predict_value = " " + to_predict_value
+ to_predict_list.append(to_predict_value)
+
+ if get_conversation:
+ conversation = conversation + sentence_value
+
+ conversation = conversation + BEGIN_SIGNAL
+ return conversation, to_predict_list
+
+
+DATASETS = dict(
+ image=[
+ dict(path="datasets/InstructionTuning/image/llava_v1_5_mix665k_image.json", type='image'),
+ dict(path='datasets/InstructionTuning/image/cococap_train.json', type='image'),
+ dict(path="datasets/InstructionTuning/image/llava_v1_5_mix665k_text.json", type='text'),
+ ],
+ audio=[
+ dict(path="datasets/InstructionTuning/audio/audiocap_train.json", type='audio'),
+ dict(path="datasets/InstructionTuning/audio/audiocap_val.json", type='audio'),
+ dict(path="datasets/InstructionTuning/audio/audio_conversation.json", type='audio'),
+ ],
+ video=[
+ dict(path="datasets/InstructionTuning/video/msrvtt_cap_trainval.json", type='video'),
+ dict(path="datasets/InstructionTuning/video/msrvtt_cap_test.json", type='video'),
+ dict(path="datasets/InstructionTuning/video/msrvtt_vqa_train.json", type='video'),
+ dict(path="datasets/InstructionTuning/video/msrvtt_vqa_val.json", type='video'),
+ dict(path="datasets/InstructionTuning/video/msrvtt_vqa_test.json", type='video'),
+ dict(path="datasets/InstructionTuning/video/video_complex_reasoning_10k.json", type='video'),
+ dict(path="datasets/InstructionTuning/video/video_conversation_10k.json", type='video'),
+ dict(path="datasets/InstructionTuning/video/video_detail_10k.json", type='video'),
+ ],
+ point=[
+ dict(path="datasets/InstructionTuning/point/pointllm_70k_formated.json", type='point'),
+ ],
+ rgbd=[
+ dict(path="datasets/InstructionTuning/depth_normal/llava_instruct_50k_depth.json", type='rgbd'),
+ ],
+ rgbn=[
+ dict(path="datasets/InstructionTuning/depth_normal/llava_instruct_50k_normal.json", type='rgbn'),
+ ],
+ imu=[
+ dict(path="datasets/InstructionTuning/imu/imu_fixed_50k.json", type='imu'),
+ ],
+ fmri=[
+ dict(path="datasets/InstructionTuning/fmri/fmri_fixed.json", type='fmri'),
+ ],
+)
+IMU_PATH = "/mnt/petrelfs/share_data/hanjiaming/ego4d/v2/processed_imu/"
+
+
+class FinetuneDialogDataset(Dataset):
+ def __init__(self, dataset=['image'], transform=T_random_resized_crop, max_words=2048, image_words=30, tokenizer_path=None):
+ if isinstance(dataset, str):
+ dataset = [dataset]
+
+ self.dataset = dataset
+
+ group_ann = {}
+ for d in dataset:
+ for meta in DATASETS[d]:
+ meta_path, meta_type = meta['path'], meta['type']
+ meta_ext = os.path.splitext(meta_path)[-1]
+ if meta_ext == ".json":
+ with open(meta_path) as f:
+ meta_l = json.load(f)
+ # add data_type
+ # this is a temp solution
+ new_meta_l = []
+ for l in meta_l:
+ l['data_type'] = meta_type
+ new_meta_l.append(l)
+ meta_l = new_meta_l
+ elif meta_ext == ".jsonl":
+ meta_l = []
+ with open(meta_path) as f:
+ for i, line in enumerate(f):
+ try:
+ meta_l.append(json.loads(line))
+ except json.decoder.JSONDecodeError as e:
+ print(
+ f"Error decoding the following jsonl line ({i}):\n{line.rstrip()}", force=True)
+ raise e
+ else:
+ raise NotImplementedError(
+ f"Unknown meta file extension: \"{meta_ext}\". "
+ f"Currently, .json, .jsonl are supported. "
+ "If you are using a supported format, please set the file extension so that the proper parsing "
+ "routine can be called."
+ )
+ if meta_type not in group_ann:
+ group_ann[meta_type] = []
+ print(f"{meta_path}, type {meta_type}: len {len(meta_l)}")
+ group_ann[meta_type] += meta_l
+
+ # sort group_ann for higher efficiency (items in one global batch with similar length)
+ for meta_type, meta_l in group_ann.items():
+ meta_l.sort(key=lambda data_item: sum(
+ [len(_['value']) for _ in data_item['conversations']]))
+
+ self.group_ann = group_ann
+ self.ann = sum(list(self.group_ann.values()), start=[])
+
+ self.group_indices = {}
+ start_pos = 0
+ for meta_type, meta_l in self.group_ann.items():
+ self.group_indices[meta_type] = list(
+ range(start_pos, start_pos + len(meta_l)))
+ start_pos = start_pos + len(meta_l)
+
+ print(f"total length: {len(self)}")
+ self.transform = transform
+ print(f"transform:\n{self.transform}")
+ self.max_words = max_words
+ self.image_words = image_words
+ self.tokenizer = Tokenizer(model_path=tokenizer_path)
+ self.conversation_generator = ConversationGenerator(self.tokenizer)
+
+ self.load_funcs = dict(
+ image=self.load_image,
+ audio=self.load_audio,
+ video=self.load_video,
+ point=self.load_point,
+ rgbd=self.load_rgbx,
+ rgbn=self.load_rgbx,
+ imu=self.load_imu,
+ fmri=self.load_fmri
+ )
+
+ def __len__(self):
+ return len(self.ann)
+
+ def load_image(self, data):
+ filename = data['image']
+ image = Image.open(filename).convert('RGB')
+ image = self.transform(image)
+ return image
+
+ def load_audio(self, data):
+ audio_path = data['image']
+ fbank = make_audio_features(audio_path, mel_bins=128)
+ fbank = fbank.transpose(0, 1)[None] # [1, 128, 1024]
+ return fbank
+
+ def load_video(self, data):
+ video_path = data['image']
+ video_feats = video_utils.load_and_transform_video_data(
+ video_path, video_path, clip_duration=1, clips_per_video=5)
+ return video_feats[:, :, 0]
+
+ def load_point(self, data):
+ point_path = data['image']
+ point_feat = torch.load(point_path, map_location='cpu')
+ point_feat = point_feat.transpose(0, 1)
+ return point_feat
+
+ def load_rgbx(self, data):
+ image_path = data['image']
+ x_image_path = data['depth_image'] if 'depth_image' in data else data['normal_image']
+ image = Image.open(image_path).convert('RGB')
+ x_image = Image.open(x_image_path).convert('RGB')
+ x_image = x_image.resize(image.size[-2:])
+
+ image, x_image = transform_pairimg_train([image, x_image])
+ # [2, 3, H, W]
+ image = torch.stack([image, x_image], dim=0)
+ return image
+
+ def load_fmri(self, data):
+ fmri_path = data['image']
+ data = np.load(fmri_path)
+ data = data.mean(axis=0)
+ data = torch.tensor(data[None])
+ return data
+
+ def load_imu(self, data_dict):
+ uid = data_dict["video_uid"]
+ w_s = data_dict["window_start"]
+ w_e = data_dict["window_end"]
+
+ imu_data = get_imu_frames(
+ IMU_PATH, uid,
+ video_start_sec=w_s,
+ video_end_sec=w_e,
+ )
+ if imu_data is None:
+ raise ValueError
+ return imu_data['signal']
+
+ def __getitem__(self, index, expect_type=None):
+ if expect_type is None:
+ data_item = self.ann[index]
+ else:
+ # in case we want get data from specific data_type
+ data_item = self.group_ann[expect_type][index]
+
+ data_type = data_item['data_type']
+ if data_type != 'text':
+ if data_type in self.load_funcs:
+ try:
+ image = self.load_funcs[data_type](data_item)
+                    if image is None:
+ raise ValueError('Data is None')
+                except Exception:
+                    print('Error', data_item)
+                    # randint is inclusive on both ends, so cap at len - 1
+                    rand_idx = random.randint(
+                        0, len(self.group_ann[data_type]) - 1)
+ return self.__getitem__(rand_idx, expect_type=data_type)
+ else:
+ raise ValueError(f'Does not support {data_type}')
+ else:
+ image = None
+ # warnings.warn("pure black image for examples without image")
+ # image = torch.zeros(3, 224, 224)
+
+ source = data_item["conversations"]
+ conversation, to_predict_values = self.conversation_generator.add_speaker_and_signal(
+ source)
+ if len(to_predict_values) == 0:
+ warnings.warn(
+ f"see dialog data with nothing to predict, data: {data_item}")
+ return self[index-1]
+
+        tokenized_conversation = self.tokenizer.encode(
+            conversation, bos=True, eos=True)
+        labels = [IGNORE_INDEX for _ in tokenized_conversation]
+
+        check_pos = 0
+        for value in to_predict_values:
+            tokenized_value = self.tokenizer.encode(
+                value, bos=False, eos=False)
+            # locate the value within the remaining conversation before
+            # offsetting by check_pos, so the -1 "not found" case is detected
+            rel_pos = find_sublist(
+                tokenized_conversation[check_pos:], tokenized_value)
+            if rel_pos == -1:
+                print(
+                    "a sentence mismatches the corresponding piece in the conversation")
+                return self[index-1]
+            value_pos = rel_pos + check_pos
+            labels[value_pos:value_pos+len(tokenized_value)] = tokenized_value
+            assert labels[value_pos:value_pos+len(
+                tokenized_value)] == tokenized_conversation[value_pos:value_pos+len(tokenized_value)]
+            check_pos = value_pos+len(tokenized_value)
+
+        input2 = torch.tensor(tokenized_conversation, dtype=torch.int64)
+ labels = torch.tensor(labels, dtype=torch.int64)
+
+ if image is not None:
+ max_words = self.max_words - self.image_words
+ else:
+ max_words = self.max_words
+ padding = max_words - input2.shape[0]
+ if padding > 0:
+ input2 = torch.cat(
+ (input2, torch.zeros(padding, dtype=torch.int64) - 1))
+ labels = torch.cat(
+ (labels, torch.zeros(padding, dtype=torch.int64) - 1))
+ elif padding < 0:
+ input2 = input2[:max_words]
+ labels = labels[:max_words]
+
+ input2_mask = input2.ge(0)
+ label_mask = labels.ge(0)
+ input2[~input2_mask] = 0
+ labels[~label_mask] = 0
+ input2_mask = input2_mask.float()
+ label_mask = label_mask.float()
+ if image is None:
+ return input2, labels, data_item['data_type']
+ else:
+ return input2, labels, image, data_item['data_type']
+
+ def groups(self):
+ return list(self.group_indices.values())
+
+
+def find_sublist(a: list, b: list):
+ len_a, len_b = len(a), len(b)
+ for i in range(len_a - len_b + 1):
+ if a[i:i+len_b] == b:
+ return i
+ return -1
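+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative only): assumes the JSON files listed in
+    # DATASETS and the LLaMA tokenizer exist at the paths below.
+    dataset = FinetuneDialogDataset(
+        dataset=['image'], max_words=2048,
+        tokenizer_path="config/llama2/tokenizer.model")
+    sample = dataset[0]  # (tokens, labels, data_type) or (tokens, labels, image, data_type)
+    print(len(sample), sample[-1])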
diff --git a/data/imu_utils.py b/data/imu_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..010563d67e603bd7ca5589672058380a79ee93d9
--- /dev/null
+++ b/data/imu_utils.py
@@ -0,0 +1,257 @@
+import string
+import numpy as np
+import matplotlib.animation as animation
+from matplotlib import pyplot as plt
+import json
+from collections import defaultdict
+from bisect import bisect_left
+import os
+import torch
+import torchaudio
+torchaudio.set_audio_backend("sox_io")
+
+
+def load_json(json_path: str):
+ """
+ Load a json file
+ """
+ with open(json_path, "r", encoding="utf-8") as f_name:
+ data = json.load(f_name)
+ return data
+
+
+def check_window_signal(info_t, w_s, w_e):
+ length = w_e - w_s
+ frame_offset = int(w_s * info_t.sample_rate)
+ num_frames = int(length * info_t.sample_rate)
+ if frame_offset + num_frames > int(info_t.num_frames):
+ return False
+ else:
+ return True
+
+
+def index_narrations(ann_path):
+ narration_raw = load_json(ann_path)
+
+ narration_dict = defaultdict(list)
+ summary_dict = defaultdict(list)
+ avg_len = []
+ for v_id, narr in narration_raw.items():
+ narr_list = []
+ summ_list = []
+ if "narration_pass_1" in narr:
+ narr_list += narr["narration_pass_1"]["narrations"]
+ summ_list += narr["narration_pass_1"]["summaries"]
+ if "narration_pass_2" in narr:
+ narr_list += narr["narration_pass_2"]["narrations"]
+ summ_list += narr["narration_pass_2"]["summaries"]
+
+ if len(narr_list) > 0:
+ narration_dict[v_id] = [
+ (
+ float(n_t["timestamp_sec"]),
+ n_t["narration_text"],
+ n_t["annotation_uid"],
+ n_t["timestamp_frame"],
+ )
+ for n_t in narr_list
+ ]
+ avg_len.append(len(narration_dict[v_id]))
+ else:
+ narration_dict[v_id] = []
+ if len(summ_list) > 0:
+ summary_dict[v_id] = [
+ (
+ float(s_t["start_sec"]),
+ float(s_t["end_sec"]),
+ s_t["summary_text"],
+ )
+ for s_t in summ_list
+ ]
+ else:
+ summary_dict[v_id] = []
+ # print(f"Number of Videos with narration {len(narration_dict)}")
+ # print(f"Avg. narration length {np.mean(avg_len)}")
+ # print(f"Number of Videos with summaries {len(summary_dict)}")
+ return narration_dict, summary_dict
+
+
+def get_signal_info(signal_fn: str):
+ return torchaudio.info(signal_fn)
+
+
+def get_signal_frames(signal_fn: str, video_start_sec: float, video_end_sec: float):
+ """
+ Given a signal track return the frames between video_start_sec and video_end_sec
+ """
+ info_t = get_signal_info(signal_fn)
+
+ length = video_end_sec - video_start_sec
+ aframes, _ = torchaudio.load(
+ signal_fn,
+ normalize=True,
+ frame_offset=int(video_start_sec * info_t.sample_rate),
+ num_frames=int(length * info_t.sample_rate),
+ )
+ return {"signal": aframes, "meta": info_t}
+
+
+def tosec(value):
+ return value / 1000
+
+
+def toms(value):
+ return value * 1000
+
+
+def delta(first_num: float, second_num: float):
+ """Compute the absolute value of the difference of two numbers"""
+ return abs(first_num - second_num)
+
+
+def padIMU(signal, duration_sec):
+ """
+ Pad the signal if necessary
+ """
+ expected_elements = round(duration_sec) * 200
+
+ if signal.shape[0] > expected_elements:
+ signal = signal[:expected_elements, :]
+ elif signal.shape[0] < expected_elements:
+ padding = expected_elements - signal.shape[0]
+ padded_zeros = np.zeros((padding, 6))
+ signal = np.concatenate([signal, padded_zeros], 0)
+ # signal = signal[:expected_elements, :]
+ return signal
+
+
+def resample(
+ signals: np.ndarray,
+ timestamps: np.ndarray,
+ original_sample_rate: int,
+ resample_rate: int,
+):
+ """
+ Resamples data to new sample rate
+ """
+ signals = torch.as_tensor(signals)
+ timestamps = torch.from_numpy(timestamps).unsqueeze(-1)
+ signals = torchaudio.functional.resample(
+ waveform=signals.data.T,
+ orig_freq=original_sample_rate,
+ new_freq=resample_rate,
+ ).T.numpy()
+
+ nsamples = len(signals)
+
+ period = 1 / resample_rate
+
+ # timestamps are expected to be shape (N, 1)
+    initial_seconds = timestamps[0] / 1e3
+
+    ntimes = (torch.arange(nsamples) * period).view(-1, 1) + initial_seconds
+
+ timestamps = (ntimes * 1e3).squeeze().numpy()
+ return signals, timestamps
+
+
+def resampleIMU(signal, timestamps):
+ sampling_rate = int(1000 * (1 / (np.mean(np.diff(timestamps)))))
+ # resample all to 200hz
+ if sampling_rate != 200:
+ signal, timestamps = resample(signal, timestamps, sampling_rate, 200)
+ return signal, timestamps
+
+
+def get_imu_frames(
+ imu_path,
+ uid: str,
+ video_start_sec: float,
+ video_end_sec: float,
+):
+ """
+ Given a IMU signal return the frames between video_start_sec and video_end_sec
+ """
+ signal = np.load(os.path.join(imu_path, f"{uid}.npy"))
+ signal = signal.transpose()
+ timestamps = np.load(os.path.join(imu_path, f"{uid}_timestamps.npy"))
+
+ if toms(video_start_sec) > timestamps[-1] or toms(video_end_sec) > timestamps[-1]:
+ return None
+
+ start_id = bisect_left(timestamps, toms(video_start_sec))
+ end_id = bisect_left(timestamps, toms(video_end_sec))
+
+    # make sure the retrieved window interval is correct, within a 4 sec margin
+ if (
+ delta(video_start_sec, tosec(timestamps[start_id])) > 4
+ or delta(video_end_sec, tosec(timestamps[end_id])) > 4
+ ):
+ return None
+
+ # get the window
+ if start_id == end_id:
+ start_id -= 1
+ end_id += 1
+ signal, timestamps = signal[start_id:end_id], timestamps[start_id:end_id]
+
+ if len(signal) < 10 or len(timestamps) < 10:
+ return None
+ # resample the signal at 200hz if necessary
+ signal, timestamps = resampleIMU(signal, timestamps)
+
+ # pad the signal if necessary
+ signal = padIMU(signal, video_end_sec - video_start_sec)
+
+ sample_dict = {
+ "timestamp": timestamps,
+ "signal": torch.tensor(signal.T),
+ "sampling_rate": 200,
+ }
+
+ return sample_dict
+
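+# A minimal usage sketch for get_imu_frames (the path and uid below are hypothetical):
+#   sample = get_imu_frames("datasets/Ego4D/imu", "<video-uid>", video_start_sec=4.0, video_end_sec=9.0)
+#   if sample is not None:
+#       sample["signal"].shape  # torch.Size([6, 1000]): 6 channels, 5 s * 200 Hz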
+
+def display_animation(frames, title, save_path_gif):
+ fig, ax = plt.subplots()
+ frames = [[ax.imshow(frames[i])] for i in range(len(frames))]
+ plt.title(title)
+ ani = animation.ArtistAnimation(fig, frames)
+ ani.save(save_path_gif, writer="imagemagick")
+ plt.close()
+
+
+def display_animation_imu(frames, imu, title, save_path_gif):
+ fig, (ax1, ax2, ax3) = plt.subplots(3, 1)
+ ax1.set_title(title)
+ ax2.set_title("Acc.")
+ ax3.set_title("Gyro.")
+ frames = [[ax1.imshow(frames[i])] for i in range(len(frames))]
+ ani = animation.ArtistAnimation(fig, frames)
+
+ ax2.plot(imu[0].cpu().numpy(), color="red")
+ ax2.plot(imu[1].cpu().numpy(), color="blue")
+ ax2.plot(imu[2].cpu().numpy(), color="green")
+ ax3.plot(imu[3].cpu().numpy(), color="red")
+ ax3.plot(imu[4].cpu().numpy(), color="blue")
+ ax3.plot(imu[5].cpu().numpy(), color="green")
+ plt.tight_layout()
+ ani.save(save_path_gif, writer="imagemagick")
+ plt.close()
+
+
+def filter_narration(narration_text: str) -> bool:
+    return "#c" in narration_text.lower()
+
+
+def clean_narration_text(narration_text: str) -> str:
+ return (
+ narration_text.replace("#C C ", "")
+ .replace("#C", "")
+ .replace("#unsure", "something")
+ .strip()
+ .strip(string.punctuation)
+ .lower()[:128]
+ )
diff --git a/data/video_utils.py b/data/video_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..43ac03067e50c8570d422a057b9d9efb18e8775b
--- /dev/null
+++ b/data/video_utils.py
@@ -0,0 +1,204 @@
+import math
+import torch
+import torch.nn as nn
+from pytorchvideo import transforms as pv_transforms
+from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
+from pytorchvideo.data.encoded_video import EncodedVideo
+from pytorchvideo.data.encoded_video_decord import EncodedVideoDecord
+from torchvision import transforms
+from torchvision.transforms._transforms_video import NormalizeVideo
+
+
+def get_clip_timepoints(clip_sampler, duration):
+ # Read out all clips in this video
+ all_clips_timepoints = []
+ is_last_clip = False
+ end = 0.0
+ while not is_last_clip:
+ start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
+ all_clips_timepoints.append((start, end))
+ return all_clips_timepoints
+
+
+
+def crop_boxes(boxes, x_offset, y_offset):
+ """
+ Perform crop on the bounding boxes given the offsets.
+ Args:
+ boxes (ndarray or None): bounding boxes to perform crop. The dimension
+ is `num boxes` x 4.
+ x_offset (int): cropping offset in the x axis.
+ y_offset (int): cropping offset in the y axis.
+ Returns:
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
+ `num boxes` x 4.
+ """
+ cropped_boxes = boxes.copy()
+ cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
+ cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
+
+ return cropped_boxes
+
+
+def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
+ """
+ Perform uniform spatial sampling on the images and corresponding boxes.
+ Args:
+ images (tensor): images to perform uniform crop. The dimension is
+ `num frames` x `channel` x `height` x `width`.
+        size (int): size of height and width to crop the images.
+ spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
+ is larger than height. Or 0, 1, or 2 for top, center, and bottom
+ crop if height is larger than width.
+ boxes (ndarray or None): optional. Corresponding boxes to images.
+ Dimension is `num boxes` x 4.
+        scale_size (int): optional. If not None, resize the images to scale_size before
+ performing any crop.
+ Returns:
+ cropped (tensor): images with dimension of
+ `num frames` x `channel` x `size` x `size`.
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
+ `num boxes` x 4.
+ """
+ assert spatial_idx in [0, 1, 2]
+ ndim = len(images.shape)
+ if ndim == 3:
+ images = images.unsqueeze(0)
+ height = images.shape[2]
+ width = images.shape[3]
+
+ if scale_size is not None:
+ if width <= height:
+ width, height = scale_size, int(height / width * scale_size)
+ else:
+ width, height = int(width / height * scale_size), scale_size
+ images = torch.nn.functional.interpolate(
+ images,
+ size=(height, width),
+ mode="bilinear",
+ align_corners=False,
+ )
+
+ y_offset = int(math.ceil((height - size) / 2))
+ x_offset = int(math.ceil((width - size) / 2))
+
+ if height > width:
+ if spatial_idx == 0:
+ y_offset = 0
+ elif spatial_idx == 2:
+ y_offset = height - size
+ else:
+ if spatial_idx == 0:
+ x_offset = 0
+ elif spatial_idx == 2:
+ x_offset = width - size
+ cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
+ cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
+ if ndim == 3:
+ cropped = cropped.squeeze(0)
+ return cropped, cropped_boxes
+
+
+class SpatialCrop(nn.Module):
+ """
+ Convert the video into 3 smaller clips spatially. Must be used after the
+ temporal crops to get spatial crops, and should be used with
+ -2 in the spatial crop at the slowfast augmentation stage (so full
+ frames are passed in here). Will return a larger list with the
+ 3x spatial crops as well.
+ """
+
+ def __init__(self, crop_size: int = 224, num_crops: int = 3):
+ super().__init__()
+ self.crop_size = crop_size
+ if num_crops == 3:
+ self.crops_to_ext = [0, 1, 2]
+ self.flipped_crops_to_ext = []
+ elif num_crops == 1:
+ self.crops_to_ext = [1]
+ self.flipped_crops_to_ext = []
+ else:
+ raise NotImplementedError("Nothing else supported yet")
+
+ def forward(self, videos):
+ """
+ Args:
+ videos: A list of C, T, H, W videos.
+ Returns:
+ videos: A list with 3x the number of elements. Each video converted
+ to C, T, H', W' by spatial cropping.
+ """
+ assert isinstance(videos, list), "Must be a list of videos after temporal crops"
+ assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)"
+ res = []
+ for video in videos:
+ for spatial_idx in self.crops_to_ext:
+ res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
+ if not self.flipped_crops_to_ext:
+ continue
+ flipped_video = transforms.functional.hflip(video)
+ for spatial_idx in self.flipped_crops_to_ext:
+ res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
+ return res
+
+
+def load_and_transform_video_data(
+ video_file,
+ video_path,
+ clip_duration=2,
+ clips_per_video=5,
+ sample_rate=16000,
+ with_audio=False
+):
+ video_transform = transforms.Compose(
+ [
+ pv_transforms.ShortSideScale(224),
+ NormalizeVideo(
+ mean=(0.48145466, 0.4578275, 0.40821073),
+ std=(0.26862954, 0.26130258, 0.27577711),
+ ),
+ ]
+ )
+
+ clip_sampler = ConstantClipsPerVideoSampler(
+ clip_duration=clip_duration, clips_per_video=clips_per_video
+ )
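+    # uniformly subsample clip_duration frames from each decoded clip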
+ frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration)
+
+ if isinstance(video_file, str):
+ video = EncodedVideo.from_path(
+ video_file,
+ decoder="decord",
+ decode_audio=with_audio,
+ # **{"sample_rate": sample_rate},
+ )
+ else:
+ video = EncodedVideoDecord(video_file, video_name=video_path, decode_video=True, decode_audio=with_audio, sample_rate=sample_rate)
+
+ all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration)
+
+ all_video = []
+ for clip_timepoints in all_clips_timepoints:
+ # Read the clip, get frames
+ clip = video.get_clip(clip_timepoints[0], clip_timepoints[1])
+ if clip is None:
+ raise ValueError("No clip found")
+ video_clip = frame_sampler(clip["video"])
+        video_clip = video_clip / 255.0  # scale uint8 pixel values into the 0-1 float range
+
+ all_video.append(video_clip)
+
+ all_video = [video_transform(clip) for clip in all_video]
+ all_video = SpatialCrop(224, num_crops=3)(all_video)
+
+ all_video = torch.stack(all_video, dim=0)
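+    # all_video: (clips_per_video * 3 spatial crops, C, T, H, W)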
+
+ if not with_audio:
+ return all_video
+ else:
+ return all_video, clip['audio']
+
+if __name__ == '__main__':
+ video_path = "datasets/InstructionTuning/video/music_aqa/MUSIC-AVQA-videos-Real/00000002.mp4"
+ video, audio = load_and_transform_video_data(video_path, video_path, clip_duration=1, clips_per_video=5, with_audio=True)
+ import pdb;pdb.set_trace()
\ No newline at end of file
diff --git a/demos/multi_turn_mm.py b/demos/multi_turn_mm.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f354e6c68d0a09df50c87a1a53f110a4fe7321a
--- /dev/null
+++ b/demos/multi_turn_mm.py
@@ -0,0 +1,300 @@
+import sys
+import os
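+# make the repository root importable when this script is run directly from demos/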
+sys.path.append(os.path.abspath(__file__).rsplit('/', 2)[0])
+
+import argparse
+import multiprocessing as mp
+import numpy as np
+from typing import List, Optional
+
+import torch
+import torch.distributed as dist
+
+from fairscale.nn.model_parallel import initialize as fs_init
+
+import gradio as gr
+from util.misc import setup_for_distributed
+from util.misc import default_tensor_type
+from model.meta import MetaModel
+from data.conversation_lib import conv_templates, SeparatorStyle
+from PIL import Image
+import torchvision.transforms as transforms
+from data.fintune_dataset import make_audio_features
+from data import video_utils
+
+
+T_random_resized_crop = transforms.Compose([
+ transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333), interpolation=3,
+ antialias=None), # 3 is bicubic
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
+
+
+def load_audio(audio_path):
+ fbank = make_audio_features(audio_path, mel_bins=128)
+ fbank = fbank.transpose(0, 1)[None] #[1, 128, 1024]
+ return fbank
+
+def load_video(video_path):
+ video_feats = video_utils.load_and_transform_video_data(video_path, video_path, clip_duration=1, clips_per_video=5)
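+    # clip_duration=1, so each clip holds a single frame; [:, :, 0] drops the temporal dim -> (clips * crops, C, H, W)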
+ return video_feats[:, :, 0]
+
+
+def model_worker(
+ rank: int, args: argparse.Namespace, barrier: mp.Barrier,
+ request_queue: mp.Queue, response_queue: Optional[mp.Queue] = None,
+) -> None:
+ """
+    The worker function that runs inference on the GPU.
+    Exactly n_gpu workers are started, each operating on a separate GPU.
+
+ Args:
+ rank (int): Distributed rank of the worker.
+ args (argparse.Namespace): All command line arguments.
+        barrier (multiprocessing.Barrier): A barrier used to delay the start
+            of the Web UI until after the model has started.
+        request_queue (mp.Queue): The queue of incoming generation requests.
+        response_queue (mp.Queue, optional): The queue to which the rank-0
+            worker streams partial responses.
+    """
+
+ world_size = len(args.gpu_ids)
+ gpu_id = args.gpu_ids[rank]
+ dist.init_process_group(
+ backend="nccl", rank=rank, world_size=world_size,
+ init_method=f"tcp://{args.master_addr}:{args.master_port}",
+ )
+ print(f"| distributed init on worker {rank}/{world_size}. "
+ f"using gpu: {gpu_id}")
+ fs_init.initialize_model_parallel(world_size)
+ torch.cuda.set_device(gpu_id)
+
+ torch.manual_seed(1)
+ np.random.seed(1)
+
+ # set the print behavior.
+ setup_for_distributed(rank == 0)
+
+ target_dtype = {
+ "bf16": torch.bfloat16,
+ "fp16": torch.float16
+ }[args.dtype]
+ with default_tensor_type(dtype=target_dtype, device="cuda"):
+ model = MetaModel(args.llama_type, args.llama_config, tokenizer_path=args.tokenizer_path)
+ print("Loading pretrained weights ...")
+ checkpoint = torch.load(args.pretrained_path, map_location='cpu')
+ msg = model.load_state_dict(checkpoint, strict=False)
+ print("load result:\n", msg)
+ model.cuda()
+ model.eval()
+ print(f"Model = {str(model)}")
+
+ barrier.wait()
+
+ while True:
+ img_path, audio_path, video_path, chatbot, max_gen_len, temperature, top_p, modality = request_queue.get()
+ if 'image' in modality and img_path is not None:
+ image = Image.open(img_path).convert('RGB')
+ inputs = T_random_resized_crop(image)
+ elif 'video' in modality and video_path is not None:
+ inputs = load_video(video_path)
+ elif 'audio' in modality and audio_path is not None:
+ inputs = load_audio(audio_path)
+ else:
+ inputs = None
+
+ if inputs is not None:
+ inputs = inputs[None].cuda().to(target_dtype)
+
+ conv = conv_templates["v1"].copy()
+ for user, bot in chatbot:
+ conv.append_message(conv.roles[0], user)
+ conv.append_message(conv.roles[1], bot)
+
+ with torch.cuda.amp.autocast(dtype=target_dtype):
+ print(conv.get_prompt())
+ for stream_response in model.stream_generate(
+ conv.get_prompt(), inputs,
+ max_gen_len=max_gen_len, temperature=temperature, top_p=top_p,
+ modal = modality
+ ):
+ conv_sep = (
+ conv.sep
+ if conv.sep_style == SeparatorStyle.SINGLE
+ else conv.sep2
+ )
+ end_pos = stream_response["text"].find(conv_sep)
+ if end_pos != -1:
+ stream_response["text"] = (
+ stream_response['text'][:end_pos].rstrip() + "\n"
+ )
+ stream_response["end_of_content"] = True
+
+ # keep a few characters if not end_of_content to avoid sending
+ # part of conv_sep before all of it is generated.
+ if not stream_response["end_of_content"]:
+ if len(stream_response["text"]) < len(conv_sep):
+ continue
+ stream_response["text"] = (
+ stream_response["text"][:-len(conv_sep)]
+ )
+
+ if response_queue is not None:
+ response_queue.put(stream_response)
+
+ if stream_response["end_of_content"]:
+ break
+
+
+def gradio_worker(
+ request_queues: List[mp.Queue], response_queue: mp.Queue,
+ args: argparse.Namespace, barrier: mp.Barrier,
+) -> None:
+ """
+    The gradio worker is responsible for displaying the WebUI and relaying the
+    requests to model workers. It should be launched only once.
+
+    Args:
+        request_queues (List[mp.Queue]): A list of request queues (one for
+            each model worker).
+        response_queue (mp.Queue): The queue from which responses of the
+            rank-0 model worker are read.
+        args (argparse.Namespace): All command line arguments.
+        barrier (multiprocessing.Barrier): A barrier used to delay the start
+            of the Web UI until after the model has started.
+    """
+
+ def show_user_input(msg, chatbot):
+ return "", chatbot + [[msg, None]]
+
+ def stream_model_output(img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality):
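+        # broadcast the request to every model worker; only the rank-0 worker writes to response_queue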
+ for queue in request_queues:
+ queue.put((img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality))
+ while True:
+ content_piece = response_queue.get()
+ chatbot[-1][1] = content_piece["text"]
+ yield chatbot
+ if content_piece["end_of_content"]:
+ break
+
+ def undo(chatbot):
+ if len(chatbot) > 0:
+ chatbot = chatbot[:-1]
+ return chatbot
+
+ def clear():
+ chatbot = []
+ msg = ""
+ return chatbot, msg
+
+ CSS ="""
+ .contain { display: flex; flex-direction: column; }
+ #component-0 { height: 100%; }
+ #chatbot { flex-grow: 1; overflow: auto;}
+ """
+ with gr.Blocks(css=CSS) as demo:
+ gr.Markdown("## OneLLM: One Framework to Align All Modalities with Language")
+ with gr.Row(equal_height=True):
+ with gr.Column(scale=1):
+ img_path = gr.Image(label='Image Input', type='filepath')
+ video_path = gr.Video(label='Video Input')
+ audio_path = gr.Audio(label='Audio Input', type='filepath', sources=['upload'])
+ modality = gr.Radio(choices=['image', 'audio', 'video'], value='image', interactive=True, label='Input Modalities')
+
+ with gr.Column(scale=2):
+ chatbot = gr.Chatbot(elem_id="chatbot")
+ msg = gr.Textbox()
+
+ with gr.Row():
+ submit_button = gr.Button("Submit", variant="primary")
+ undo_button = gr.Button("Undo")
+ clear_button = gr.ClearButton([chatbot, msg, img_path, audio_path, video_path, modality])
+ with gr.Row():
+ max_gen_len = gr.Slider(
+ minimum=1, maximum=args.model_max_seq_len // 2,
+ value=args.model_max_seq_len // 2, interactive=True,
+ label="Single-turn max response length",
+ )
+ gen_t = gr.Slider(
+ minimum=0, maximum=1, value=0.1, interactive=True,
+ label="Temperature",
+ )
+ top_p = gr.Slider(
+ minimum=0, maximum=1, value=0.75, interactive=True,
+ label="Top-p",
+ )
+ msg.submit(
+ show_user_input, [msg, chatbot], [msg, chatbot],
+ ).then(
+ stream_model_output, [img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot,
+ )
+ submit_button.click(
+ show_user_input, [msg, chatbot], [msg, chatbot],
+ ).then(
+ stream_model_output, [img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot,
+ )
+ undo_button.click(undo, chatbot, chatbot)
+ # img_path.change(clear, [], [chatbot, msg])
+ barrier.wait()
+ demo.queue(api_open=True).launch(share=True, max_threads=1)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser("Chat Demo")
+ group = parser.add_mutually_exclusive_group()
+ group.add_argument(
+ "--gpu_ids", type=int, nargs="+",
+ help="A list of space-separated gpu ids to run the model on. "
+ "The model will span across GPUs in tensor-parallel mode."
+ )
+ parser.add_argument(
+ "--tokenizer_path", type=str,
+ help="Path to the tokenizer.model file provided along with the LLaMA "
+ "model."
+ )
+ parser.add_argument(
+ "--llama_type", default="onellm", type=str, metavar="MODEL",
+ help="LLaMA model type."
+ )
+ parser.add_argument(
+ "--llama_config", type=str, required=True,
+ help="Path to the llama model config json."
+ )
+ parser.add_argument(
+ "--model_max_seq_len", type=int, default=2048,
+ help="Max sequence length accepted by the pretrained model."
+ )
+ parser.add_argument(
+ "--pretrained_path", type=str, required=True,
+ help="Path to the llama model checkpoints. A list of checkpoints is "
+ "supported and will be merged from left to right.")
+ parser.add_argument(
+ "--master_port", type=int, default=23862,
+ help="A port used by the PyTorch distributed module to initialize."
+ )
+ parser.add_argument(
+ "--master_addr", type=str, default="127.0.0.1",
+ help="An address used by the PyTorch distributed module to initialize."
+ )
+ parser.add_argument(
+ "--dtype", type=str, choices=["fp16", "bf16"], default="fp16",
+ help="The dtype used for model weights and inference."
+ )
+ args = parser.parse_args()
+
+ # using the default "fork" method messes up some imported libs (e.g.,
+ # pandas)
+ mp.set_start_method("spawn")
+
+ # setup the queues and start the model workers
+ request_queues = []
+ response_queue = mp.Queue()
+ worker_processes = []
+ barrier = mp.Barrier(len(args.gpu_ids) + 1)
+ for rank, gpu_id in enumerate(args.gpu_ids):
+ request_queue = mp.Queue()
+ rank_response_queue = response_queue if rank == 0 else None
+ process = mp.Process(
+ target=model_worker,
+ args=(rank, args, barrier, request_queue, rank_response_queue),
+ )
+ process.start()
+ worker_processes.append(process)
+ request_queues.append(request_queue)
+
+ gradio_worker(request_queues, response_queue, args, barrier)
diff --git a/lib/__pycache__/point_utils.cpython-310.pyc b/lib/__pycache__/point_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b52bf4169d4d84233f3178c745896d1fa395824f
Binary files /dev/null and b/lib/__pycache__/point_utils.cpython-310.pyc differ
diff --git a/lib/point_utils.py b/lib/point_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..834733a64b540a141bfce09f6d0fae3154f89997
--- /dev/null
+++ b/lib/point_utils.py
@@ -0,0 +1,191 @@
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+import pointnet2_cuda
+
+class KNN(nn.Module):
+ def __init__(self, neighbors, transpose_mode=True):
+ super(KNN, self).__init__()
+ self.neighbors = neighbors
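+        # transpose_mode is accepted for interface compatibility but is not used by this implementation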
+
+ @torch.no_grad()
+ def forward(self, support, query):
+ """
+ Args:
+ support ([tensor]): [B, N, C]
+ query ([tensor]): [B, M, C]
+ Returns:
+ [int]: neighbor idx. [B, M, K]
+ """
+ dist = torch.cdist(support, query)
+ k_dist = dist.topk(k=self.neighbors, dim=1, largest=False)
+ return k_dist.values, k_dist.indices.transpose(1, 2).contiguous().int()
+
+
+class GroupingOperation(Function):
+
+ @staticmethod
+ @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
+ def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
+ """
+ :param ctx:
+ :param features: (B, C, N) tensor of features to group
+        :param idx: (B, npoint, nsample) tensor containing the indices of features to group with
+ :return:
+ output: (B, C, npoint, nsample) tensor
+ """
+ assert features.is_contiguous()
+ assert idx.is_contiguous()
+
+ B, nfeatures, nsample = idx.size()
+ _, C, N = features.size()
+ output = torch.cuda.FloatTensor(B, C, nfeatures, nsample, device=features.device)
+
+ pointnet2_cuda.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output)
+
+ ctx.for_backwards = (idx, N)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_out: torch.Tensor):
+ """
+ :param ctx:
+ :param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward
+ :return:
+ grad_features: (B, C, N) gradient of the features
+ """
+ idx, N = ctx.for_backwards
+
+ B, C, npoint, nsample = grad_out.size()
+ grad_features = torch.zeros([B, C, N], dtype=torch.float, device=grad_out.device, requires_grad=True)
+ grad_out_data = grad_out.data.contiguous()
+ pointnet2_cuda.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data)
+ return grad_features, None
+
+grouping_operation = GroupingOperation.apply
+
+
+class KNNGroup(nn.Module):
+ def __init__(self, nsample: int,
+ relative_xyz=True,
+ normalize_dp=False,
+ return_only_idx=False,
+ **kwargs
+ ):
+        """Group the nsample nearest neighbors of each query point.
+
+        Args:
+            nsample (int): number of neighbors to gather for each query point
+            relative_xyz (bool, optional): return neighbor coordinates relative to the query point. Defaults to True.
+            normalize_dp (bool, optional): normalize the relative coordinates by the largest neighbor distance. Defaults to False.
+            return_only_idx (bool, optional): return only the neighbor indices. Defaults to False.
+        """
+ super().__init__()
+ self.nsample = nsample
+ self.knn = KNN(nsample, transpose_mode=True)
+ self.relative_xyz = relative_xyz
+ self.normalize_dp = normalize_dp
+ self.return_only_idx = return_only_idx
+
+ def forward(self, query_xyz: torch.Tensor, support_xyz: torch.Tensor, features: torch.Tensor = None):
+ """
+        :param query_xyz: (B, npoint, 3) query centroids to group around
+        :param support_xyz: (B, N, 3) xyz coordinates of the support points
+        :param features: (B, C, N) descriptors of the support points
+        :return:
+            grouped_xyz: (B, 3, npoint, nsample) and grouped_features: (B, C, npoint, nsample) or None
+ """
+ _, idx = self.knn(support_xyz, query_xyz)
+ if self.return_only_idx:
+ return idx
+ idx = idx.int()
+ xyz_trans = support_xyz.transpose(1, 2).contiguous()
+ grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample)
+ if self.relative_xyz:
+ grouped_xyz -= query_xyz.transpose(1, 2).unsqueeze(-1) # relative position
+ if self.normalize_dp:
+ grouped_xyz /= torch.amax(torch.sqrt(torch.sum(grouped_xyz**2, dim=1)), dim=(1, 2)).view(-1, 1, 1, 1)
+ if features is not None:
+ grouped_features = grouping_operation(features, idx)
+ return grouped_xyz, grouped_features
+ else:
+ return grouped_xyz, None
+
+
+class FurthestPointSampling(Function):
+ @staticmethod
+ def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
+ """
+ Uses iterative furthest point sampling to select a set of npoint features that have the largest
+ minimum distance
+ :param ctx:
+ :param xyz: (B, N, 3) where N > npoint
+ :param npoint: int, number of features in the sampled set
+ :return:
+ output: (B, npoint) tensor containing the set (idx)
+ """
+ assert xyz.is_contiguous()
+
+ B, N, _ = xyz.size()
+ # output = torch.cuda.IntTensor(B, npoint, device=xyz.device)
+ # temp = torch.cuda.FloatTensor(B, N, device=xyz.device).fill_(1e10)
+ output = torch.cuda.IntTensor(B, npoint)
+ temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
+
+ pointnet2_cuda.furthest_point_sampling_wrapper(
+ B, N, npoint, xyz, temp, output)
+ return output
+
+ @staticmethod
+ def backward(xyz, a=None):
+ return None, None
+
+furthest_point_sample = FurthestPointSampling.apply
+
+
+class PointPatchEmbed(nn.Module):
+
+ def __init__(self,
+ sample_ratio=0.0625,
+ sample_number=1024,
+ group_size=32,
+ in_channels=6,
+ channels=1024,
+ kernel_size=1,
+ stride=1,
+ normalize_dp=False,
+ relative_xyz=True,
+ ):
+ super().__init__()
+ self.sample_ratio = sample_ratio
+ self.sample_number = sample_number
+ self.group_size = group_size
+
+ self.sample_fn = furthest_point_sample
+ self.grouper = KNNGroup(self.group_size, relative_xyz=relative_xyz, normalize_dp=normalize_dp)
+
+ self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=kernel_size, stride=stride)
+
+
+ def forward(self, x):
+ # coordinates
+ p = x[:, :, 3:].contiguous()
+
+ B, N, _ = p.shape[:3]
+ # idx = self.sample_fn(p, int(N * self.sample_ratio)).long()
+ idx = self.sample_fn(p, self.sample_number).long()
+ center_p = torch.gather(p, 1, idx.unsqueeze(-1).expand(-1, -1, 3))
+ # query neighbors.
+ _, fj = self.grouper(center_p, p, x.permute(0, 2, 1).contiguous()) # [B, N, 6] -> [B, 6, N] -> [B, 6, 1024, 32]
+
+        # [B, 6, 1024, 32] -> [B, channels, 1024, 32] -> [B, channels, 1024, 1]
+ fj = self.conv1(fj).max(dim=-1, keepdim=True)[0]
+
+ return fj
+
+
+if __name__ == '__main__':
+ model = PointPatchEmbed(channels=256).cuda()
+ input = torch.rand(4, 16384, 6).cuda()
+ ou = model(input)
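+    # ou: (4, 256, 1024, 1) = (batch, channels, sample_number, 1)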
+ import pdb;pdb.set_trace()
\ No newline at end of file
diff --git a/lib/pointnet2/pointnet2_modules.py b/lib/pointnet2/pointnet2_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f125ce5075c738897e5f6a78c71123d0e3e44a2
--- /dev/null
+++ b/lib/pointnet2/pointnet2_modules.py
@@ -0,0 +1,160 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from . import pointnet2_utils
+from . import pytorch_utils as pt_utils
+from typing import List
+
+
+class _PointnetSAModuleBase(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.npoint = None
+ self.groupers = None
+ self.mlps = None
+ self.pool_method = 'max_pool'
+
+ def forward(self, xyz: torch.Tensor, features: torch.Tensor = None, new_xyz=None) -> (torch.Tensor, torch.Tensor):
+ """
+ :param xyz: (B, N, 3) tensor of the xyz coordinates of the features
+        :param features: (B, N, C) tensor of the descriptors of the features
+ :param new_xyz:
+ :return:
+ new_xyz: (B, npoint, 3) tensor of the new features' xyz
+ new_features: (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors
+ """
+ new_features_list = []
+
+ xyz_flipped = xyz.transpose(1, 2).contiguous()
+ if new_xyz is None:
+ new_xyz = pointnet2_utils.gather_operation(
+ xyz_flipped,
+ pointnet2_utils.furthest_point_sample(xyz, self.npoint)
+ ).transpose(1, 2).contiguous() if self.npoint is not None else None
+
+ for i in range(len(self.groupers)):
+ new_features = self.groupers[i](xyz, new_xyz, features) # (B, C, npoint, nsample)
+
+ new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample)
+ if self.pool_method == 'max_pool':
+ new_features = F.max_pool2d(
+ new_features, kernel_size=[1, new_features.size(3)]
+ ) # (B, mlp[-1], npoint, 1)
+ elif self.pool_method == 'avg_pool':
+ new_features = F.avg_pool2d(
+ new_features, kernel_size=[1, new_features.size(3)]
+ ) # (B, mlp[-1], npoint, 1)
+ else:
+ raise NotImplementedError
+
+ new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint)
+ new_features_list.append(new_features)
+
+ return new_xyz, torch.cat(new_features_list, dim=1)
+
+
+class PointnetSAModuleMSG(_PointnetSAModuleBase):
+ """Pointnet set abstraction layer with multiscale grouping"""
+
+ def __init__(self, *, npoint: int, radii: List[float], nsamples: List[int], mlps: List[List[int]], bn: bool = True,
+ use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
+ """
+ :param npoint: int
+ :param radii: list of float, list of radii to group with
+ :param nsamples: list of int, number of samples in each ball query
+ :param mlps: list of list of int, spec of the pointnet before the global pooling for each scale
+ :param bn: whether to use batchnorm
+ :param use_xyz:
+ :param pool_method: max_pool / avg_pool
+ :param instance_norm: whether to use instance_norm
+ """
+ super().__init__()
+
+ assert len(radii) == len(nsamples) == len(mlps)
+
+ self.npoint = npoint
+ self.groupers = nn.ModuleList()
+ self.mlps = nn.ModuleList()
+ for i in range(len(radii)):
+ radius = radii[i]
+ nsample = nsamples[i]
+ self.groupers.append(
+ pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
+ if npoint is not None else pointnet2_utils.GroupAll(use_xyz)
+ )
+ mlp_spec = mlps[i]
+ if use_xyz:
+ mlp_spec[0] += 3
+
+ self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn, instance_norm=instance_norm))
+ self.pool_method = pool_method
+
+
+class PointnetSAModule(PointnetSAModuleMSG):
+ """Pointnet set abstraction layer"""
+
+ def __init__(self, *, mlp: List[int], npoint: int = None, radius: float = None, nsample: int = None,
+ bn: bool = True, use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
+ """
+ :param mlp: list of int, spec of the pointnet before the global max_pool
+ :param npoint: int, number of features
+ :param radius: float, radius of ball
+ :param nsample: int, number of samples in the ball query
+ :param bn: whether to use batchnorm
+ :param use_xyz:
+ :param pool_method: max_pool / avg_pool
+ :param instance_norm: whether to use instance_norm
+ """
+ super().__init__(
+ mlps=[mlp], npoint=npoint, radii=[radius], nsamples=[nsample], bn=bn, use_xyz=use_xyz,
+ pool_method=pool_method, instance_norm=instance_norm
+ )
+
+
+class PointnetFPModule(nn.Module):
+    r"""Propagates the features of one set to another"""
+
+ def __init__(self, *, mlp: List[int], bn: bool = True):
+ """
+ :param mlp: list of int
+ :param bn: whether to use batchnorm
+ """
+ super().__init__()
+ self.mlp = pt_utils.SharedMLP(mlp, bn=bn)
+
+ def forward(
+ self, unknown: torch.Tensor, known: torch.Tensor, unknow_feats: torch.Tensor, known_feats: torch.Tensor
+ ) -> torch.Tensor:
+ """
+ :param unknown: (B, n, 3) tensor of the xyz positions of the unknown features
+ :param known: (B, m, 3) tensor of the xyz positions of the known features
+        :param unknow_feats: (B, C1, n) tensor of the features to be propagated to
+        :param known_feats: (B, C2, m) tensor of features to be propagated
+ :return:
+ new_features: (B, mlp[-1], n) tensor of the features of the unknown features
+ """
+ if known is not None:
+ dist, idx = pointnet2_utils.three_nn(unknown, known)
+ dist_recip = 1.0 / (dist + 1e-8)
+ norm = torch.sum(dist_recip, dim=2, keepdim=True)
+ weight = dist_recip / norm
+
+ interpolated_feats = pointnet2_utils.three_interpolate(known_feats, idx, weight)
+ else:
+ interpolated_feats = known_feats.expand(*known_feats.size()[0:2], unknown.size(1))
+
+ if unknow_feats is not None:
+ new_features = torch.cat([interpolated_feats, unknow_feats], dim=1) # (B, C2 + C1, n)
+ else:
+ new_features = interpolated_feats
+
+ new_features = new_features.unsqueeze(-1)
+ new_features = self.mlp(new_features)
+
+ return new_features.squeeze(-1)
+
+
+if __name__ == "__main__":
+ pass
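+    # Minimal usage sketch (assumes a CUDA device and the compiled pointnet2_cuda extension):
+    #   xyz = torch.rand(2, 1024, 3).cuda()
+    #   sa = PointnetSAModule(mlp=[0, 32, 64], npoint=256, radius=0.2, nsample=32).cuda()
+    #   new_xyz, new_features = sa(xyz)  # (2, 256, 3), (2, 64, 256)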
diff --git a/lib/pointnet2/pointnet2_utils.py b/lib/pointnet2/pointnet2_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e814102d8feb5e443e64a736e7733818e0a24685
--- /dev/null
+++ b/lib/pointnet2/pointnet2_utils.py
@@ -0,0 +1,290 @@
+import torch
+from torch.autograd import Variable
+from torch.autograd import Function
+import torch.nn as nn
+from typing import Tuple
+
+import pointnet2_cuda as pointnet2
+
+
+class FurthestPointSampling(Function):
+ @staticmethod
+ def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
+ """
+ Uses iterative furthest point sampling to select a set of npoint features that have the largest
+ minimum distance
+ :param ctx:
+ :param xyz: (B, N, 3) where N > npoint
+ :param npoint: int, number of features in the sampled set
+ :return:
+ output: (B, npoint) tensor containing the set
+ """
+ assert xyz.is_contiguous()
+
+ B, N, _ = xyz.size()
+ output = torch.cuda.IntTensor(B, npoint)
+ temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
+
+ pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
+ return output
+
+ @staticmethod
+ def backward(xyz, a=None):
+ return None, None
+
+
+furthest_point_sample = FurthestPointSampling.apply
+
+
+class GatherOperation(Function):
+
+ @staticmethod
+ def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
+ """
+ :param ctx:
+ :param features: (B, C, N)
+ :param idx: (B, npoint) index tensor of the features to gather
+ :return:
+ output: (B, C, npoint)
+ """
+ assert features.is_contiguous()
+ assert idx.is_contiguous()
+
+ B, npoint = idx.size()
+ _, C, N = features.size()
+ output = torch.cuda.FloatTensor(B, C, npoint)
+
+ pointnet2.gather_points_wrapper(B, C, N, npoint, features, idx, output)
+
+ ctx.for_backwards = (idx, C, N)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_out):
+ idx, C, N = ctx.for_backwards
+ B, npoint = idx.size()
+
+ grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
+ grad_out_data = grad_out.data.contiguous()
+ pointnet2.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features.data)
+ return grad_features, None
+
+
+gather_operation = GatherOperation.apply
+
+
+class ThreeNN(Function):
+
+ @staticmethod
+ def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Find the three nearest neighbors of unknown in known
+ :param ctx:
+ :param unknown: (B, N, 3)
+ :param known: (B, M, 3)
+ :return:
+ dist: (B, N, 3) l2 distance to the three nearest neighbors
+ idx: (B, N, 3) index of 3 nearest neighbors
+ """
+ assert unknown.is_contiguous()
+ assert known.is_contiguous()
+
+ B, N, _ = unknown.size()
+ m = known.size(1)
+ dist2 = torch.cuda.FloatTensor(B, N, 3)
+ idx = torch.cuda.IntTensor(B, N, 3)
+
+ pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx)
+ return torch.sqrt(dist2), idx
+
+ @staticmethod
+ def backward(ctx, a=None, b=None):
+ return None, None
+
+
+three_nn = ThreeNN.apply
+
+
+class ThreeInterpolate(Function):
+
+ @staticmethod
+ def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
+ """
+        Performs weighted linear interpolation on 3 features
+ :param ctx:
+ :param features: (B, C, M) Features descriptors to be interpolated from
+ :param idx: (B, n, 3) three nearest neighbors of the target features in features
+ :param weight: (B, n, 3) weights
+ :return:
+ output: (B, C, N) tensor of the interpolated features
+ """
+ assert features.is_contiguous()
+ assert idx.is_contiguous()
+ assert weight.is_contiguous()
+
+ B, c, m = features.size()
+ n = idx.size(1)
+ ctx.three_interpolate_for_backward = (idx, weight, m)
+ output = torch.cuda.FloatTensor(B, c, n)
+
+ pointnet2.three_interpolate_wrapper(B, c, m, n, features, idx, weight, output)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ :param ctx:
+ :param grad_out: (B, C, N) tensor with gradients of outputs
+ :return:
+ grad_features: (B, C, M) tensor with gradients of features
+ None:
+ None:
+ """
+ idx, weight, m = ctx.three_interpolate_for_backward
+ B, c, n = grad_out.size()
+
+ grad_features = Variable(torch.cuda.FloatTensor(B, c, m).zero_())
+ grad_out_data = grad_out.data.contiguous()
+
+ pointnet2.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data, idx, weight, grad_features.data)
+ return grad_features, None, None
+
+
+three_interpolate = ThreeInterpolate.apply
+
+
+class GroupingOperation(Function):
+
+ @staticmethod
+ def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
+ """
+ :param ctx:
+ :param features: (B, C, N) tensor of features to group
+        :param idx: (B, npoint, nsample) tensor containing the indices of features to group with
+ :return:
+ output: (B, C, npoint, nsample) tensor
+ """
+ assert features.is_contiguous()
+ assert idx.is_contiguous()
+
+ B, nfeatures, nsample = idx.size()
+ _, C, N = features.size()
+ output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)
+
+ pointnet2.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output)
+
+ ctx.for_backwards = (idx, N)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ :param ctx:
+ :param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward
+ :return:
+ grad_features: (B, C, N) gradient of the features
+ """
+ idx, N = ctx.for_backwards
+
+ B, C, npoint, nsample = grad_out.size()
+ grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
+
+ grad_out_data = grad_out.data.contiguous()
+ pointnet2.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data)
+ return grad_features, None
+
+
+grouping_operation = GroupingOperation.apply
+
+
+class BallQuery(Function):
+
+ @staticmethod
+ def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor:
+ """
+ :param ctx:
+ :param radius: float, radius of the balls
+ :param nsample: int, maximum number of features in the balls
+ :param xyz: (B, N, 3) xyz coordinates of the features
+ :param new_xyz: (B, npoint, 3) centers of the ball query
+ :return:
+            idx: (B, npoint, nsample) tensor with the indices of the features that form the query balls
+ """
+ assert new_xyz.is_contiguous()
+ assert xyz.is_contiguous()
+
+ B, N, _ = xyz.size()
+ npoint = new_xyz.size(1)
+ idx = torch.cuda.IntTensor(B, npoint, nsample).zero_()
+
+ pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx)
+ return idx
+
+ @staticmethod
+ def backward(ctx, a=None):
+ return None, None, None, None
+
+
+ball_query = BallQuery.apply
+
+
+class QueryAndGroup(nn.Module):
+ def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
+ """
+ :param radius: float, radius of ball
+ :param nsample: int, maximum number of features to gather in the ball
+ :param use_xyz:
+ """
+ super().__init__()
+ self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
+
+ def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None) -> Tuple[torch.Tensor]:
+ """
+ :param xyz: (B, N, 3) xyz coordinates of the features
+ :param new_xyz: (B, npoint, 3) centroids
+ :param features: (B, C, N) descriptors of the features
+ :return:
+ new_features: (B, 3 + C, npoint, nsample)
+ """
+ idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
+ xyz_trans = xyz.transpose(1, 2).contiguous()
+ grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample)
+ grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)
+
+ if features is not None:
+ grouped_features = grouping_operation(features, idx)
+ if self.use_xyz:
+ new_features = torch.cat([grouped_xyz, grouped_features], dim=1) # (B, C + 3, npoint, nsample)
+ else:
+ new_features = grouped_features
+ else:
+            assert self.use_xyz, "Cannot have both features=None and use_xyz=False!"
+ new_features = grouped_xyz
+
+ return new_features
+
+
+class GroupAll(nn.Module):
+ def __init__(self, use_xyz: bool = True):
+ super().__init__()
+ self.use_xyz = use_xyz
+
+ def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None):
+ """
+ :param xyz: (B, N, 3) xyz coordinates of the features
+ :param new_xyz: ignored
+ :param features: (B, C, N) descriptors of the features
+ :return:
+ new_features: (B, C + 3, 1, N)
+ """
+ grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
+ if features is not None:
+ grouped_features = features.unsqueeze(2)
+ if self.use_xyz:
+ new_features = torch.cat([grouped_xyz, grouped_features], dim=1) # (B, 3 + C, 1, N)
+ else:
+ new_features = grouped_features
+ else:
+ new_features = grouped_xyz
+
+ return new_features
diff --git a/lib/pointnet2/pytorch_utils.py b/lib/pointnet2/pytorch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..09cb7bc76d88dde5757ac70b6e05e1e0c768cc1b
--- /dev/null
+++ b/lib/pointnet2/pytorch_utils.py
@@ -0,0 +1,236 @@
+import torch.nn as nn
+from typing import List, Tuple
+
+
+class SharedMLP(nn.Sequential):
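+    """A stack of pointwise (1x1) Conv2d blocks; `args` lists the per-layer channel widths."""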
+
+ def __init__(
+ self,
+ args: List[int],
+ *,
+ bn: bool = False,
+ activation=nn.ReLU(inplace=True),
+ preact: bool = False,
+ first: bool = False,
+ name: str = "",
+ instance_norm: bool = False,
+ ):
+ super().__init__()
+
+ for i in range(len(args) - 1):
+ self.add_module(
+ name + 'layer{}'.format(i),
+ Conv2d(
+ args[i],
+ args[i + 1],
+ bn=(not first or not preact or (i != 0)) and bn,
+ activation=activation
+ if (not first or not preact or (i != 0)) else None,
+ preact=preact,
+ instance_norm=instance_norm
+ )
+ )
+
+
+class _ConvBase(nn.Sequential):
+
+ def __init__(
+ self,
+ in_size,
+ out_size,
+ kernel_size,
+ stride,
+ padding,
+ activation,
+ bn,
+ init,
+ conv=None,
+ batch_norm=None,
+ bias=True,
+ preact=False,
+ name="",
+ instance_norm=False,
+ instance_norm_func=None
+ ):
+ super().__init__()
+
+ bias = bias and (not bn)
+ conv_unit = conv(
+ in_size,
+ out_size,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ bias=bias
+ )
+ init(conv_unit.weight)
+ if bias:
+ nn.init.constant_(conv_unit.bias, 0)
+
+ if bn:
+ if not preact:
+ bn_unit = batch_norm(out_size)
+ else:
+ bn_unit = batch_norm(in_size)
+ if instance_norm:
+ if not preact:
+ in_unit = instance_norm_func(out_size, affine=False, track_running_stats=False)
+ else:
+ in_unit = instance_norm_func(in_size, affine=False, track_running_stats=False)
+
+ if preact:
+ if bn:
+ self.add_module(name + 'bn', bn_unit)
+
+ if activation is not None:
+ self.add_module(name + 'activation', activation)
+
+ if not bn and instance_norm:
+ self.add_module(name + 'in', in_unit)
+
+ self.add_module(name + 'conv', conv_unit)
+
+ if not preact:
+ if bn:
+ self.add_module(name + 'bn', bn_unit)
+
+ if activation is not None:
+ self.add_module(name + 'activation', activation)
+
+ if not bn and instance_norm:
+ self.add_module(name + 'in', in_unit)
+
+
+class _BNBase(nn.Sequential):
+
+ def __init__(self, in_size, batch_norm=None, name=""):
+ super().__init__()
+ self.add_module(name + "bn", batch_norm(in_size))
+
+ nn.init.constant_(self[0].weight, 1.0)
+ nn.init.constant_(self[0].bias, 0)
+
+
+class BatchNorm1d(_BNBase):
+
+ def __init__(self, in_size: int, *, name: str = ""):
+ super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name)
+
+
+class BatchNorm2d(_BNBase):
+
+ def __init__(self, in_size: int, name: str = ""):
+ super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name)
+
+
+class Conv1d(_ConvBase):
+
+ def __init__(
+ self,
+ in_size: int,
+ out_size: int,
+ *,
+ kernel_size: int = 1,
+ stride: int = 1,
+ padding: int = 0,
+ activation=nn.ReLU(inplace=True),
+ bn: bool = False,
+ init=nn.init.kaiming_normal_,
+ bias: bool = True,
+ preact: bool = False,
+ name: str = "",
+ instance_norm=False
+ ):
+ super().__init__(
+ in_size,
+ out_size,
+ kernel_size,
+ stride,
+ padding,
+ activation,
+ bn,
+ init,
+ conv=nn.Conv1d,
+ batch_norm=BatchNorm1d,
+ bias=bias,
+ preact=preact,
+ name=name,
+ instance_norm=instance_norm,
+ instance_norm_func=nn.InstanceNorm1d
+ )
+
+
+class Conv2d(_ConvBase):
+
+ def __init__(
+ self,
+ in_size: int,
+ out_size: int,
+ *,
+ kernel_size: Tuple[int, int] = (1, 1),
+ stride: Tuple[int, int] = (1, 1),
+ padding: Tuple[int, int] = (0, 0),
+ activation=nn.ReLU(inplace=True),
+ bn: bool = False,
+ init=nn.init.kaiming_normal_,
+ bias: bool = True,
+ preact: bool = False,
+ name: str = "",
+ instance_norm=False
+ ):
+ super().__init__(
+ in_size,
+ out_size,
+ kernel_size,
+ stride,
+ padding,
+ activation,
+ bn,
+ init,
+ conv=nn.Conv2d,
+ batch_norm=BatchNorm2d,
+ bias=bias,
+ preact=preact,
+ name=name,
+ instance_norm=instance_norm,
+ instance_norm_func=nn.InstanceNorm2d
+ )
+
+
+class FC(nn.Sequential):
+
+ def __init__(
+ self,
+ in_size: int,
+ out_size: int,
+ *,
+ activation=nn.ReLU(inplace=True),
+ bn: bool = False,
+ init=None,
+ preact: bool = False,
+ name: str = ""
+ ):
+ super().__init__()
+
+ fc = nn.Linear(in_size, out_size, bias=not bn)
+ if init is not None:
+ init(fc.weight)
+ if not bn:
+            nn.init.constant_(fc.bias, 0)
+
+ if preact:
+ if bn:
+ self.add_module(name + 'bn', BatchNorm1d(in_size))
+
+ if activation is not None:
+ self.add_module(name + 'activation', activation)
+
+ self.add_module(name + 'fc', fc)
+
+ if not preact:
+ if bn:
+ self.add_module(name + 'bn', BatchNorm1d(out_size))
+
+ if activation is not None:
+ self.add_module(name + 'activation', activation)
+
diff --git a/lib/pointnet2/setup.py b/lib/pointnet2/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..99e59e37b90517cc38c35d100f7f9cee0e309368
--- /dev/null
+++ b/lib/pointnet2/setup.py
@@ -0,0 +1,23 @@
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+setup(
+ name='pointnet2',
+ ext_modules=[
+ CUDAExtension('pointnet2_cuda', [
+ 'src/pointnet2_api.cpp',
+
+ 'src/ball_query.cpp',
+ 'src/ball_query_gpu.cu',
+ 'src/group_points.cpp',
+ 'src/group_points_gpu.cu',
+ 'src/interpolate.cpp',
+ 'src/interpolate_gpu.cu',
+ 'src/sampling.cpp',
+ 'src/sampling_gpu.cu',
+ ],
+ extra_compile_args={'cxx': ['-g'],
+ 'nvcc': ['-O2']})
+ ],
+ cmdclass={'build_ext': BuildExtension}
+)
diff --git a/lib/pointnet2/src/ball_query.cpp b/lib/pointnet2/src/ball_query.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..c9b176e5da5dd89a3378652f0b806925e8ee8996
--- /dev/null
+++ b/lib/pointnet2/src/ball_query.cpp
@@ -0,0 +1,24 @@
+#include <torch/serialize/tensor.h>
+#include <vector>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAEvent.h>
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+#include "ball_query_gpu.h"
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
+#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
+
+int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample,
+ at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor) {
+ CHECK_INPUT(new_xyz_tensor);
+ CHECK_INPUT(xyz_tensor);
+    const float *new_xyz = new_xyz_tensor.data<float>();
+    const float *xyz = xyz_tensor.data<float>();
+    int *idx = idx_tensor.data<int>();
+
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+ ball_query_kernel_launcher_fast(b, n, m, radius, nsample, new_xyz, xyz, idx, stream);
+ return 1;
+}
diff --git a/lib/pointnet2/src/ball_query_gpu.cu b/lib/pointnet2/src/ball_query_gpu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..f8840aa6650693cea17d337008a15fef13ec1ebc
--- /dev/null
+++ b/lib/pointnet2/src/ball_query_gpu.cu
@@ -0,0 +1,67 @@
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "ball_query_gpu.h"
+#include "cuda_utils.h"
+
+
+__global__ void ball_query_kernel_fast(int b, int n, int m, float radius, int nsample,
+ const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx) {
+ // new_xyz: (B, M, 3)
+ // xyz: (B, N, 3)
+ // output:
+ // idx: (B, M, nsample)
+ int bs_idx = blockIdx.y;
+ int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (bs_idx >= b || pt_idx >= m) return;
+
+ new_xyz += bs_idx * m * 3 + pt_idx * 3;
+ xyz += bs_idx * n * 3;
+ idx += bs_idx * m * nsample + pt_idx * nsample;
+
+ float radius2 = radius * radius;
+ float new_x = new_xyz[0];
+ float new_y = new_xyz[1];
+ float new_z = new_xyz[2];
+
+ int cnt = 0;
+ for (int k = 0; k < n; ++k) {
+ float x = xyz[k * 3 + 0];
+ float y = xyz[k * 3 + 1];
+ float z = xyz[k * 3 + 2];
+ float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z);
+ if (d2 < radius2){
+ if (cnt == 0){
+ for (int l = 0; l < nsample; ++l) {
+ idx[l] = k;
+ }
+ }
+ idx[cnt] = k;
+ ++cnt;
+ if (cnt >= nsample) break;
+ }
+ }
+}
+
+
+void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, \
+ const float *new_xyz, const float *xyz, int *idx, cudaStream_t stream) {
+ // new_xyz: (B, M, 3)
+ // xyz: (B, N, 3)
+ // output:
+ // idx: (B, M, nsample)
+
+ cudaError_t err;
+
+ dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row)
+ dim3 threads(THREADS_PER_BLOCK);
+
+    ball_query_kernel_fast<<<blocks, threads, 0, stream>>>(b, n, m, radius, nsample, new_xyz, xyz, idx);
+ // cudaDeviceSynchronize(); // for using printf in kernel function
+ err = cudaGetLastError();
+ if (cudaSuccess != err) {
+ fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+ exit(-1);
+ }
+}
\ No newline at end of file
diff --git a/lib/pointnet2/src/ball_query_gpu.h b/lib/pointnet2/src/ball_query_gpu.h
new file mode 100644
index 0000000000000000000000000000000000000000..ffc831a8b700f46b50e0b90d49c538aa0fedca50
--- /dev/null
+++ b/lib/pointnet2/src/ball_query_gpu.h
@@ -0,0 +1,15 @@
+#ifndef _BALL_QUERY_GPU_H
+#define _BALL_QUERY_GPU_H
+
+#include <torch/serialize/tensor.h>
+#include <vector>
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+
+int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample,
+ at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor);
+
+void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample,
+ const float *xyz, const float *new_xyz, int *idx, cudaStream_t stream);
+
+#endif
diff --git a/lib/pointnet2/src/cuda_utils.h b/lib/pointnet2/src/cuda_utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..7fe27969179c976a88199bbe962ca4f8d97263a4
--- /dev/null
+++ b/lib/pointnet2/src/cuda_utils.h
@@ -0,0 +1,15 @@
+#ifndef _CUDA_UTILS_H
+#define _CUDA_UTILS_H
+
+#include <cmath>
+
+#define TOTAL_THREADS 1024
+#define THREADS_PER_BLOCK 256
+#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
+
+inline int opt_n_threads(int work_size) {
+    const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
+
+ return max(min(1 << pow_2, TOTAL_THREADS), 1);
+}
+#endif
diff --git a/lib/pointnet2/src/group_points.cpp b/lib/pointnet2/src/group_points.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..fa80f0e318acc57dabf76ec0a8b1d9dff482ab89
--- /dev/null
+++ b/lib/pointnet2/src/group_points.cpp
@@ -0,0 +1,34 @@
+#include <torch/serialize/tensor.h>
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+#include <vector>
+#include "group_points_gpu.h"
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAEvent.h>
+
+
+
+int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample,
+ at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) {
+
+    float *grad_points = grad_points_tensor.data<float>();
+    const int *idx = idx_tensor.data<int>();
+    const float *grad_out = grad_out_tensor.data<float>();
+
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+ group_points_grad_kernel_launcher_fast(b, c, n, npoints, nsample, grad_out, idx, grad_points, stream);
+ return 1;
+}
+
+
+int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample,
+ at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor) {
+
+    const float *points = points_tensor.data<float>();
+    const int *idx = idx_tensor.data<int>();
+    float *out = out_tensor.data<float>();
+
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+ group_points_kernel_launcher_fast(b, c, n, npoints, nsample, points, idx, out, stream);
+ return 1;
+}
diff --git a/lib/pointnet2/src/group_points_gpu.cu b/lib/pointnet2/src/group_points_gpu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..c015a8125e38aafa1f960000044978463b7853b1
--- /dev/null
+++ b/lib/pointnet2/src/group_points_gpu.cu
@@ -0,0 +1,86 @@
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "cuda_utils.h"
+#include "group_points_gpu.h"
+
+
+__global__ void group_points_grad_kernel_fast(int b, int c, int n, int npoints, int nsample,
+ const float *__restrict__ grad_out, const int *__restrict__ idx, float *__restrict__ grad_points) {
+ // grad_out: (B, C, npoints, nsample)
+ // idx: (B, npoints, nsample)
+ // output:
+ // grad_points: (B, C, N)
+ int bs_idx = blockIdx.z;
+ int c_idx = blockIdx.y;
+ int index = blockIdx.x * blockDim.x + threadIdx.x;
+ int pt_idx = index / nsample;
+ if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;
+
+ int sample_idx = index % nsample;
+ grad_out += bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx;
+ idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx;
+
+ atomicAdd(grad_points + bs_idx * c * n + c_idx * n + idx[0] , grad_out[0]);
+}
+
+void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
+ const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) {
+ // grad_out: (B, C, npoints, nsample)
+ // idx: (B, npoints, nsample)
+ // output:
+ // grad_points: (B, C, N)
+ cudaError_t err;
+ dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
+ dim3 threads(THREADS_PER_BLOCK);
+
+    group_points_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, nsample, grad_out, idx, grad_points);
+
+ err = cudaGetLastError();
+ if (cudaSuccess != err) {
+ fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+ exit(-1);
+ }
+}
+
+
+__global__ void group_points_kernel_fast(int b, int c, int n, int npoints, int nsample,
+ const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) {
+ // points: (B, C, N)
+ // idx: (B, npoints, nsample)
+ // output:
+ // out: (B, C, npoints, nsample)
+ int bs_idx = blockIdx.z;
+ int c_idx = blockIdx.y;
+ int index = blockIdx.x * blockDim.x + threadIdx.x;
+ int pt_idx = index / nsample;
+ if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;
+
+ int sample_idx = index % nsample;
+
+ idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx;
+ int in_idx = bs_idx * c * n + c_idx * n + idx[0];
+ int out_idx = bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx;
+
+ out[out_idx] = points[in_idx];
+}
+
+
+void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
+ const float *points, const int *idx, float *out, cudaStream_t stream) {
+ // points: (B, C, N)
+ // idx: (B, npoints, nsample)
+ // output:
+ // out: (B, C, npoints, nsample)
+ cudaError_t err;
+ dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
+ dim3 threads(THREADS_PER_BLOCK);
+
+    group_points_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, nsample, points, idx, out);
+ // cudaDeviceSynchronize(); // for using printf in kernel function
+ err = cudaGetLastError();
+ if (cudaSuccess != err) {
+ fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+ exit(-1);
+ }
+}
diff --git a/lib/pointnet2/src/group_points_gpu.h b/lib/pointnet2/src/group_points_gpu.h
new file mode 100644
index 0000000000000000000000000000000000000000..76c73ca2600ef75c192b06d28f79a168f1ba368b
--- /dev/null
+++ b/lib/pointnet2/src/group_points_gpu.h
@@ -0,0 +1,22 @@
+#ifndef _GROUP_POINTS_GPU_H
+#define _GROUP_POINTS_GPU_H
+
+#include <torch/serialize/tensor.h>
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+#include <vector>
+
+
+int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample,
+ at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor);
+
+void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
+ const float *points, const int *idx, float *out, cudaStream_t stream);
+
+int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample,
+ at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor);
+
+void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
+ const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream);
+
+#endif
diff --git a/lib/pointnet2/src/interpolate.cpp b/lib/pointnet2/src/interpolate.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..88d837f966f52696308b7d85ec1756b2395bb986
--- /dev/null
+++ b/lib/pointnet2/src/interpolate.cpp
@@ -0,0 +1,53 @@
+#include <torch/serialize/tensor.h>
+#include <vector>
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+#include "interpolate_gpu.h"
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAEvent.h>
+
+
+void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor,
+ at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor) {
+    const float *unknown = unknown_tensor.data<float>();
+    const float *known = known_tensor.data<float>();
+    float *dist2 = dist2_tensor.data<float>();
+    int *idx = idx_tensor.data<int>();
+
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+ three_nn_kernel_launcher_fast(b, n, m, unknown, known, dist2, idx, stream);
+}
+
+
+void three_interpolate_wrapper_fast(int b, int c, int m, int n,
+ at::Tensor points_tensor,
+ at::Tensor idx_tensor,
+ at::Tensor weight_tensor,
+ at::Tensor out_tensor) {
+
+    const float *points = points_tensor.data<float>();
+    const float *weight = weight_tensor.data<float>();
+    float *out = out_tensor.data<float>();
+    const int *idx = idx_tensor.data<int>();
+
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+ three_interpolate_kernel_launcher_fast(b, c, m, n, points, idx, weight, out, stream);
+}
+
+void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m,
+ at::Tensor grad_out_tensor,
+ at::Tensor idx_tensor,
+ at::Tensor weight_tensor,
+ at::Tensor grad_points_tensor) {
+
+    const float *grad_out = grad_out_tensor.data<float>();
+    const float *weight = weight_tensor.data<float>();
+    float *grad_points = grad_points_tensor.data<float>();
+    const int *idx = idx_tensor.data<int>();
+
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+ three_interpolate_grad_kernel_launcher_fast(b, c, n, m, grad_out, idx, weight, grad_points, stream);
+}
diff --git a/lib/pointnet2/src/interpolate_gpu.cu b/lib/pointnet2/src/interpolate_gpu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..a123dd8d8d4f5ed23cc4a340abb1141d140fca3c
--- /dev/null
+++ b/lib/pointnet2/src/interpolate_gpu.cu
@@ -0,0 +1,161 @@
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "cuda_utils.h"
+#include "interpolate_gpu.h"
+
+
+__global__ void three_nn_kernel_fast(int b, int n, int m, const float *__restrict__ unknown,
+ const float *__restrict__ known, float *__restrict__ dist2, int *__restrict__ idx) {
+ // unknown: (B, N, 3)
+ // known: (B, M, 3)
+ // output:
+ // dist2: (B, N, 3)
+ // idx: (B, N, 3)
+
+ int bs_idx = blockIdx.y;
+ int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (bs_idx >= b || pt_idx >= n) return;
+
+ unknown += bs_idx * n * 3 + pt_idx * 3;
+ known += bs_idx * m * 3;
+ dist2 += bs_idx * n * 3 + pt_idx * 3;
+ idx += bs_idx * n * 3 + pt_idx * 3;
+
+ float ux = unknown[0];
+ float uy = unknown[1];
+ float uz = unknown[2];
+
+ double best1 = 1e40, best2 = 1e40, best3 = 1e40;
+ int besti1 = 0, besti2 = 0, besti3 = 0;
+ for (int k = 0; k < m; ++k) {
+ float x = known[k * 3 + 0];
+ float y = known[k * 3 + 1];
+ float z = known[k * 3 + 2];
+ float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
+ if (d < best1) {
+ best3 = best2; besti3 = besti2;
+ best2 = best1; besti2 = besti1;
+ best1 = d; besti1 = k;
+ }
+ else if (d < best2) {
+ best3 = best2; besti3 = besti2;
+ best2 = d; besti2 = k;
+ }
+ else if (d < best3) {
+ best3 = d; besti3 = k;
+ }
+ }
+ dist2[0] = best1; dist2[1] = best2; dist2[2] = best3;
+ idx[0] = besti1; idx[1] = besti2; idx[2] = besti3;
+}
+
+
+void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown,
+ const float *known, float *dist2, int *idx, cudaStream_t stream) {
+ // unknown: (B, N, 3)
+ // known: (B, M, 3)
+ // output:
+ // dist2: (B, N, 3)
+ // idx: (B, N, 3)
+
+ cudaError_t err;
+ dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row)
+ dim3 threads(THREADS_PER_BLOCK);
+
+ three_nn_kernel_fast<<<blocks, threads, 0, stream>>>(b, n, m, unknown, known, dist2, idx);
+
+ err = cudaGetLastError();
+ if (cudaSuccess != err) {
+ fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+ exit(-1);
+ }
+}
+
+
+__global__ void three_interpolate_kernel_fast(int b, int c, int m, int n, const float *__restrict__ points,
+ const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ out) {
+ // points: (B, C, M)
+ // idx: (B, N, 3)
+ // weight: (B, N, 3)
+ // output:
+ // out: (B, C, N)
+
+ int bs_idx = blockIdx.z;
+ int c_idx = blockIdx.y;
+ int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+ if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;
+
+ weight += bs_idx * n * 3 + pt_idx * 3;
+ points += bs_idx * c * m + c_idx * m;
+ idx += bs_idx * n * 3 + pt_idx * 3;
+ out += bs_idx * c * n + c_idx * n;
+
+ out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] + weight[2] * points[idx[2]];
+}
+
+void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n,
+ const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream) {
+ // points: (B, C, M)
+ // idx: (B, N, 3)
+ // weight: (B, N, 3)
+ // output:
+ // out: (B, C, N)
+
+ cudaError_t err;
+ dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
+ dim3 threads(THREADS_PER_BLOCK);
+ three_interpolate_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, m, n, points, idx, weight, out);
+
+ err = cudaGetLastError();
+ if (cudaSuccess != err) {
+ fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+ exit(-1);
+ }
+}
+
+
+__global__ void three_interpolate_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out,
+ const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ grad_points) {
+ // grad_out: (B, C, N)
+ // weight: (B, N, 3)
+ // output:
+ // grad_points: (B, C, M)
+
+ int bs_idx = blockIdx.z;
+ int c_idx = blockIdx.y;
+ int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+ if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;
+
+ grad_out += bs_idx * c * n + c_idx * n + pt_idx;
+ weight += bs_idx * n * 3 + pt_idx * 3;
+ grad_points += bs_idx * c * m + c_idx * m;
+ idx += bs_idx * n * 3 + pt_idx * 3;
+
+
+ atomicAdd(grad_points + idx[0], grad_out[0] * weight[0]);
+ atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]);
+ atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]);
+}
+
+void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out,
+ const int *idx, const float *weight, float *grad_points, cudaStream_t stream) {
+ // grad_out: (B, C, N)
+ // weight: (B, N, 3)
+ // output:
+ // grad_points: (B, C, M)
+
+ cudaError_t err;
+ dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
+ dim3 threads(THREADS_PER_BLOCK);
+ three_interpolate_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, m, grad_out, idx, weight, grad_points);
+
+ err = cudaGetLastError();
+ if (cudaSuccess != err) {
+ fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+ exit(-1);
+ }
+}
\ No newline at end of file
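As a sanity check for the interpolation kernels above, the following pure-PyTorch sketch mirrors `three_nn` (three nearest known points per query) and `three_interpolate` (weighted sum of three neighbour features). It is a reference only; tie-breaking for equal distances may differ from the CUDA version, and the inverse-distance weights at the end are the conventional choice, not something defined in this file.

```python
import torch

def three_nn_reference(unknown: torch.Tensor, known: torch.Tensor):
    # unknown: (B, N, 3), known: (B, M, 3) -> squared distances and indices of the
    # three nearest known points for every unknown point, both shaped (B, N, 3)
    d2 = torch.cdist(unknown, known) ** 2                # (B, N, M) squared Euclidean distances
    dist2, idx = torch.topk(d2, k=3, dim=-1, largest=False)
    return dist2, idx

def three_interpolate_reference(points: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor):
    # points: (B, C, M), idx/weight: (B, N, 3) -> out: (B, C, N), weighted sum of 3 neighbours
    B, C, M = points.shape
    N = idx.shape[1]
    gathered = torch.gather(points.unsqueeze(2).expand(B, C, N, M), 3,
                            idx.long().unsqueeze(1).expand(B, C, N, 3))  # (B, C, N, 3)
    return (gathered * weight.unsqueeze(1)).sum(dim=-1)

unknown, known = torch.randn(2, 256, 3), torch.randn(2, 64, 3)
dist2, idx = three_nn_reference(unknown, known)
weight = 1.0 / (dist2 + 1e-8)
weight = weight / weight.sum(dim=-1, keepdim=True)        # usual inverse-distance weights
feats = torch.randn(2, 32, 64)
print(three_interpolate_reference(feats, idx, weight).shape)  # torch.Size([2, 32, 256])
```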
diff --git a/lib/pointnet2/src/interpolate_gpu.h b/lib/pointnet2/src/interpolate_gpu.h
new file mode 100644
index 0000000000000000000000000000000000000000..f1771087c5e4146e3c5775d3b929ebffffd11ccb
--- /dev/null
+++ b/lib/pointnet2/src/interpolate_gpu.h
@@ -0,0 +1,30 @@
+#ifndef _INTERPOLATE_GPU_H
+#define _INTERPOLATE_GPU_H
+
+#include <torch/serialize/tensor.h>
+#include <vector>
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+
+
+void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor,
+ at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor);
+
+void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown,
+ const float *known, float *dist2, int *idx, cudaStream_t stream);
+
+
+void three_interpolate_wrapper_fast(int b, int c, int m, int n, at::Tensor points_tensor,
+ at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor out_tensor);
+
+void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n,
+ const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream);
+
+
+void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m, at::Tensor grad_out_tensor,
+ at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor grad_points_tensor);
+
+void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out,
+ const int *idx, const float *weight, float *grad_points, cudaStream_t stream);
+
+#endif
diff --git a/lib/pointnet2/src/pointnet2_api.cpp b/lib/pointnet2/src/pointnet2_api.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d91f0f2176a6080624f071e5535fe509a0ac83c4
--- /dev/null
+++ b/lib/pointnet2/src/pointnet2_api.cpp
@@ -0,0 +1,24 @@
+#include <torch/serialize/tensor.h>
+#include <torch/extension.h>
+
+#include "ball_query_gpu.h"
+#include "group_points_gpu.h"
+#include "sampling_gpu.h"
+#include "interpolate_gpu.h"
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("ball_query_wrapper", &ball_query_wrapper_fast, "ball_query_wrapper_fast");
+
+ m.def("group_points_wrapper", &group_points_wrapper_fast, "group_points_wrapper_fast");
+ m.def("group_points_grad_wrapper", &group_points_grad_wrapper_fast, "group_points_grad_wrapper_fast");
+
+ m.def("gather_points_wrapper", &gather_points_wrapper_fast, "gather_points_wrapper_fast");
+ m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper_fast, "gather_points_grad_wrapper_fast");
+
+ m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper");
+
+ m.def("three_nn_wrapper", &three_nn_wrapper_fast, "three_nn_wrapper_fast");
+ m.def("three_interpolate_wrapper", &three_interpolate_wrapper_fast, "three_interpolate_wrapper_fast");
+ m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_fast, "three_interpolate_grad_wrapper_fast");
+}
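For reference, an extension exporting these pybind11 wrappers is typically built with `torch.utils.cpp_extension`. The sketch below is an assumption about how such a `setup.py` could look; the repository ships its own under `lib/pointnet2`, and the module name `pointnet2_cuda` and the exact source list are illustrative.

```python
# hypothetical setup.py sketch for building the pointnet2 CUDA extension above
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="pointnet2",
    ext_modules=[
        CUDAExtension(
            name="pointnet2_cuda",            # illustrative module name
            sources=[
                "src/pointnet2_api.cpp",
                "src/ball_query.cpp", "src/ball_query_gpu.cu",
                "src/group_points.cpp", "src/group_points_gpu.cu",
                "src/interpolate.cpp", "src/interpolate_gpu.cu",
                "src/sampling.cpp", "src/sampling_gpu.cu",
            ],
            extra_compile_args={"cxx": ["-O2"], "nvcc": ["-O2"]},
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)
```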
diff --git a/lib/pointnet2/src/sampling.cpp b/lib/pointnet2/src/sampling.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..5f54daa763ed66240c17ba6254ee9d5a39b6dfc0
--- /dev/null
+++ b/lib/pointnet2/src/sampling.cpp
@@ -0,0 +1,45 @@
+#include <torch/serialize/tensor.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <vector>
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+#include "sampling_gpu.h"
+
+
+
+int gather_points_wrapper_fast(int b, int c, int n, int npoints,
+ at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor){
+ const float *points = points_tensor.data_ptr<float>();
+ const int *idx = idx_tensor.data_ptr<int>();
+ float *out = out_tensor.data_ptr<float>();
+
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+ gather_points_kernel_launcher_fast(b, c, n, npoints, points, idx, out, stream);
+ return 1;
+}
+
+
+int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints,
+ at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) {
+
+ const float *grad_out = grad_out_tensor.data_ptr<float>();
+ const int *idx = idx_tensor.data_ptr<int>();
+ float *grad_points = grad_points_tensor.data_ptr<float>();
+
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+ gather_points_grad_kernel_launcher_fast(b, c, n, npoints, grad_out, idx, grad_points, stream);
+ return 1;
+}
+
+
+int furthest_point_sampling_wrapper(int b, int n, int m,
+ at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) {
+
+ const float *points = points_tensor.data_ptr<float>();
+ float *temp = temp_tensor.data_ptr<float>();
+ int *idx = idx_tensor.data_ptr<int>();
+
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+ furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx, stream);
+ return 1;
+}
diff --git a/lib/pointnet2/src/sampling_gpu.cu b/lib/pointnet2/src/sampling_gpu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..9e49a60dd6a80449be4c6c0d0d710be7b5fe9cd5
--- /dev/null
+++ b/lib/pointnet2/src/sampling_gpu.cu
@@ -0,0 +1,253 @@
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "cuda_utils.h"
+#include "sampling_gpu.h"
+
+
+__global__ void gather_points_kernel_fast(int b, int c, int n, int m,
+ const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) {
+ // points: (B, C, N)
+ // idx: (B, M)
+ // output:
+ // out: (B, C, M)
+
+ int bs_idx = blockIdx.z;
+ int c_idx = blockIdx.y;
+ int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;
+
+ out += bs_idx * c * m + c_idx * m + pt_idx;
+ idx += bs_idx * m + pt_idx;
+ points += bs_idx * c * n + c_idx * n;
+ out[0] = points[idx[0]];
+}
+
+void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints,
+ const float *points, const int *idx, float *out, cudaStream_t stream) {
+ // points: (B, C, N)
+ // idx: (B, npoints)
+ // output:
+ // out: (B, C, npoints)
+
+ cudaError_t err;
+ dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
+ dim3 threads(THREADS_PER_BLOCK);
+
+ gather_points_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, points, idx, out);
+
+ err = cudaGetLastError();
+ if (cudaSuccess != err) {
+ fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+ exit(-1);
+ }
+}
+
+__global__ void gather_points_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out,
+ const int *__restrict__ idx, float *__restrict__ grad_points) {
+ // grad_out: (B, C, M)
+ // idx: (B, M)
+ // output:
+ // grad_points: (B, C, N)
+
+ int bs_idx = blockIdx.z;
+ int c_idx = blockIdx.y;
+ int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;
+
+ grad_out += bs_idx * c * m + c_idx * m + pt_idx;
+ idx += bs_idx * m + pt_idx;
+ grad_points += bs_idx * c * n + c_idx * n;
+
+ atomicAdd(grad_points + idx[0], grad_out[0]);
+}
+
+void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints,
+ const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) {
+ // grad_out: (B, C, npoints)
+ // idx: (B, npoints)
+ // output:
+ // grad_points: (B, C, N)
+
+ cudaError_t err;
+ dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
+ dim3 threads(THREADS_PER_BLOCK);
+
+ gather_points_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, grad_out, idx, grad_points);
+
+ err = cudaGetLastError();
+ if (cudaSuccess != err) {
+ fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+ exit(-1);
+ }
+}
+
+
+__device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, int idx1, int idx2){
+ const float v1 = dists[idx1], v2 = dists[idx2];
+ const int i1 = dists_i[idx1], i2 = dists_i[idx2];
+ dists[idx1] = max(v1, v2);
+ dists_i[idx1] = v2 > v1 ? i2 : i1;
+}
+
+template <unsigned int block_size>
+__global__ void furthest_point_sampling_kernel(int b, int n, int m,
+ const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) {
+ // dataset: (B, N, 3)
+ // tmp: (B, N)
+ // output:
+ // idx: (B, M)
+
+ if (m <= 0) return;
+ __shared__ float dists[block_size];
+ __shared__ int dists_i[block_size];
+
+ int batch_index = blockIdx.x;
+ dataset += batch_index * n * 3;
+ temp += batch_index * n;
+ idxs += batch_index * m;
+
+ int tid = threadIdx.x;
+ const int stride = block_size;
+
+ int old = 0;
+ if (threadIdx.x == 0)
+ idxs[0] = old;
+
+ __syncthreads();
+ for (int j = 1; j < m; j++) {
+ int besti = 0;
+ float best = -1;
+ float x1 = dataset[old * 3 + 0];
+ float y1 = dataset[old * 3 + 1];
+ float z1 = dataset[old * 3 + 2];
+ for (int k = tid; k < n; k += stride) {
+ float x2, y2, z2;
+ x2 = dataset[k * 3 + 0];
+ y2 = dataset[k * 3 + 1];
+ z2 = dataset[k * 3 + 2];
+ // float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
+ // if (mag <= 1e-3)
+ // continue;
+
+ float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
+ float d2 = min(d, temp[k]);
+ temp[k] = d2;
+ besti = d2 > best ? k : besti;
+ best = d2 > best ? d2 : best;
+ }
+ dists[tid] = best;
+ dists_i[tid] = besti;
+ __syncthreads();
+
+ if (block_size >= 1024) {
+ if (tid < 512) {
+ __update(dists, dists_i, tid, tid + 512);
+ }
+ __syncthreads();
+ }
+
+ if (block_size >= 512) {
+ if (tid < 256) {
+ __update(dists, dists_i, tid, tid + 256);
+ }
+ __syncthreads();
+ }
+ if (block_size >= 256) {
+ if (tid < 128) {
+ __update(dists, dists_i, tid, tid + 128);
+ }
+ __syncthreads();
+ }
+ if (block_size >= 128) {
+ if (tid < 64) {
+ __update(dists, dists_i, tid, tid + 64);
+ }
+ __syncthreads();
+ }
+ if (block_size >= 64) {
+ if (tid < 32) {
+ __update(dists, dists_i, tid, tid + 32);
+ }
+ __syncthreads();
+ }
+ if (block_size >= 32) {
+ if (tid < 16) {
+ __update(dists, dists_i, tid, tid + 16);
+ }
+ __syncthreads();
+ }
+ if (block_size >= 16) {
+ if (tid < 8) {
+ __update(dists, dists_i, tid, tid + 8);
+ }
+ __syncthreads();
+ }
+ if (block_size >= 8) {
+ if (tid < 4) {
+ __update(dists, dists_i, tid, tid + 4);
+ }
+ __syncthreads();
+ }
+ if (block_size >= 4) {
+ if (tid < 2) {
+ __update(dists, dists_i, tid, tid + 2);
+ }
+ __syncthreads();
+ }
+ if (block_size >= 2) {
+ if (tid < 1) {
+ __update(dists, dists_i, tid, tid + 1);
+ }
+ __syncthreads();
+ }
+
+ old = dists_i[0];
+ if (tid == 0)
+ idxs[j] = old;
+ }
+}
+
+void furthest_point_sampling_kernel_launcher(int b, int n, int m,
+ const float *dataset, float *temp, int *idxs, cudaStream_t stream) {
+ // dataset: (B, N, 3)
+ // tmp: (B, N)
+ // output:
+ // idx: (B, M)
+
+ cudaError_t err;
+ unsigned int n_threads = opt_n_threads(n);
+
+ switch (n_threads) {
+ case 1024:
+ furthest_point_sampling_kernel<1024><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+ case 512:
+ furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+ case 256:
+ furthest_point_sampling_kernel<256><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+ case 128:
+ furthest_point_sampling_kernel<128><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+ case 64:
+ furthest_point_sampling_kernel<64><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+ case 32:
+ furthest_point_sampling_kernel<32><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+ case 16:
+ furthest_point_sampling_kernel<16><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+ case 8:
+ furthest_point_sampling_kernel<8><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+ case 4:
+ furthest_point_sampling_kernel<4><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+ case 2:
+ furthest_point_sampling_kernel<2><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+ case 1:
+ furthest_point_sampling_kernel<1><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+ default:
+ furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
+ }
+
+ err = cudaGetLastError();
+ if (cudaSuccess != err) {
+ fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+ exit(-1);
+ }
+}
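The farthest point sampling kernel above greedily picks the point that maximises its minimum distance to the already-selected set, with `temp` holding the running minimum per point. A batched pure-PyTorch sketch of the same procedure, for reference only (names are illustrative):

```python
import torch

def furthest_point_sampling_reference(xyz: torch.Tensor, m: int) -> torch.Tensor:
    # xyz: (B, N, 3) -> (B, m) indices of the sampled points, starting from point 0 like the kernel
    B, N, _ = xyz.shape
    idxs = torch.zeros(B, m, dtype=torch.long, device=xyz.device)
    temp = torch.full((B, N), 1e10, device=xyz.device)   # running min distance to the selected set
    batch = torch.arange(B, device=xyz.device)
    farthest = torch.zeros(B, dtype=torch.long, device=xyz.device)
    for j in range(m):
        idxs[:, j] = farthest
        centroid = xyz[batch, farthest].unsqueeze(1)      # (B, 1, 3)
        d = ((xyz - centroid) ** 2).sum(-1)               # squared distances to the newest point
        temp = torch.minimum(temp, d)
        farthest = temp.argmax(dim=1)                     # next point: farthest from the selected set
    return idxs

print(furthest_point_sampling_reference(torch.randn(2, 1024, 3), 128).shape)  # torch.Size([2, 128])
```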
diff --git a/lib/pointnet2/src/sampling_gpu.h b/lib/pointnet2/src/sampling_gpu.h
new file mode 100644
index 0000000000000000000000000000000000000000..6200c5914e434ecd2fc3b36313985805f6dbe0cc
--- /dev/null
+++ b/lib/pointnet2/src/sampling_gpu.h
@@ -0,0 +1,29 @@
+#ifndef _SAMPLING_GPU_H
+#define _SAMPLING_GPU_H
+
+#include <torch/serialize/tensor.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <vector>
+
+
+int gather_points_wrapper_fast(int b, int c, int n, int npoints,
+ at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor);
+
+void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints,
+ const float *points, const int *idx, float *out, cudaStream_t stream);
+
+
+int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints,
+ at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor);
+
+void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints,
+ const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream);
+
+
+int furthest_point_sampling_wrapper(int b, int n, int m,
+ at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor);
+
+void furthest_point_sampling_kernel_launcher(int b, int n, int m,
+ const float *dataset, float *temp, int *idxs, cudaStream_t stream);
+
+#endif
diff --git a/model/LLM/__init__.py b/model/LLM/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e8eb9e9325f1906f28a9d60d967ff76963ff1a8
--- /dev/null
+++ b/model/LLM/__init__.py
@@ -0,0 +1 @@
+from . import onellm
\ No newline at end of file
diff --git a/model/LLM/__pycache__/__init__.cpython-310.pyc b/model/LLM/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e70f6416d504770062ceb50661a6094181c47ea2
Binary files /dev/null and b/model/LLM/__pycache__/__init__.cpython-310.pyc differ
diff --git a/model/LLM/__pycache__/__init__.cpython-39.pyc b/model/LLM/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..be815601063517a817e23e64c9a6208e4e66d833
Binary files /dev/null and b/model/LLM/__pycache__/__init__.cpython-39.pyc differ
diff --git a/model/LLM/__pycache__/onellm.cpython-310.pyc b/model/LLM/__pycache__/onellm.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ccf829243e41031a865186ba965b5d98d44174f1
Binary files /dev/null and b/model/LLM/__pycache__/onellm.cpython-310.pyc differ
diff --git a/model/LLM/__pycache__/onellm.cpython-39.pyc b/model/LLM/__pycache__/onellm.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..320f3ae803542ebcea9d3414a83ae4f5e5845455
Binary files /dev/null and b/model/LLM/__pycache__/onellm.cpython-39.pyc differ
diff --git a/model/LLM/onellm.py b/model/LLM/onellm.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a5195737c0448e3d83c3301acbd3fce3bcd0a4e
--- /dev/null
+++ b/model/LLM/onellm.py
@@ -0,0 +1,495 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the GNU General Public License version 3.
+
+from typing import Optional, Tuple
+from dataclasses import dataclass
+import math
+import functools
+import copy
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+import fairscale.nn.model_parallel.initialize as fs_init
+from fairscale.nn.model_parallel.layers import (
+ ParallelEmbedding,
+ RowParallelLinear,
+ ColumnParallelLinear,
+)
+from ..components import RMSNorm
+from flash_attn import flash_attn_func
+
+import open_clip
+
+
+default_linear_init = nn.init.xavier_uniform_
+
+
+@dataclass
+class ModelArgs:
+ dim: int = 512
+ n_layers: int = 8
+ n_heads: int = 8
+ vocab_size: int = -1 # defined later by tokenizer
+ multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
+ norm_eps: float = 1e-5
+
+ max_batch_size: int = 32
+ max_seq_len: int = 2048
+
+
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)
+ [: (dim // 2)].float() / dim))
+ t = torch.arange(end, device=freqs.device) # type: ignore
+ freqs = torch.outer(t, freqs).float() # type: ignore
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
+ return freqs_cis
+
+
+def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+ ndim = x.ndim
+ assert 0 <= 1 < ndim
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+ shape = [d if i == 1 or i == ndim -
+ 1 else 1 for i, d in enumerate(x.shape)]
+ return freqs_cis.view(*shape)
+
+
+def apply_rotary_emb(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+ return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+class Attention(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+
+ self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()
+ self.head_dim = args.dim // args.n_heads
+
+ self.wq = ColumnParallelLinear(
+ args.dim,
+ args.n_heads * self.head_dim,
+ bias=False,
+ gather_output=False,
+ init_method=default_linear_init,
+ )
+ self.wk = ColumnParallelLinear(
+ args.dim,
+ args.n_heads * self.head_dim,
+ bias=False,
+ gather_output=False,
+ init_method=default_linear_init,
+ )
+ self.wv = ColumnParallelLinear(
+ args.dim,
+ args.n_heads * self.head_dim,
+ bias=False,
+ gather_output=False,
+ init_method=default_linear_init,
+ )
+ self.wo = RowParallelLinear(
+ args.n_heads * self.head_dim,
+ args.dim,
+ bias=False,
+ input_is_parallel=True,
+ init_method=default_linear_init,
+ )
+
+ self.flash = True
+ self.k_cache, self.v_cache = None, None
+
+ def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], prompt=None):
+ bsz, seqlen, _ = x.shape
+ xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+
+ xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+ xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+ xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+
+ if freqs_cis is not None:
+ xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
+
+ if self.k_cache is None or self.v_cache is None:
+ keys, values = xk, xv
+ else:
+ self.k_cache = self.k_cache.to(xk)
+ self.v_cache = self.v_cache.to(xv)
+ self.k_cache[:bsz, start_pos: start_pos + seqlen, :, :] = xk
+ self.v_cache[:bsz, start_pos: start_pos + seqlen, :, :] = xv
+ keys = self.k_cache[:bsz, :start_pos + seqlen]
+ values = self.v_cache[:bsz, :start_pos + seqlen]
+
+ output = flash_attn_func(
+ xq, keys, values, dropout_p=0.0, causal=mask is not None)
+ output = output.contiguous().view(bsz, seqlen, -1)
+
+ return self.wo(output)
+
+ def allocate_kv_cache(self, max_batch_size: int, max_seq_len: int) -> None:
+ kv_cache_shape = (max_batch_size, max_seq_len,
+ self.n_local_heads, self.head_dim)
+ if self.k_cache is None or self.k_cache.size() != kv_cache_shape:
+ self.k_cache = torch.empty(kv_cache_shape)
+ if self.v_cache is None or self.v_cache.size() != kv_cache_shape:
+ self.v_cache = torch.empty(kv_cache_shape)
+
+ def destroy_kv_cache(self) -> None:
+ self.k_cache, self.v_cache = None, None
+
+
+class FeedForward(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ multiple_of: int,
+ ):
+ super().__init__()
+ hidden_dim = int(2 * hidden_dim / 3)
+ hidden_dim = multiple_of * \
+ ((hidden_dim + multiple_of - 1) // multiple_of)
+
+ self.w1 = ColumnParallelLinear(
+ dim, hidden_dim, bias=False, gather_output=False, init_method=default_linear_init,
+ )
+ self.w2 = RowParallelLinear(
+ hidden_dim, dim, bias=False, input_is_parallel=True, init_method=default_linear_init
+ )
+ self.w3 = ColumnParallelLinear(
+ dim, hidden_dim, bias=False, gather_output=False, init_method=default_linear_init
+ )
+
+ def _silu_gating(self, x, y):
+ return F.silu(x) * y
+
+ def forward(self, x):
+ return self.w2(self._silu_gating(self.w1(x), self.w3(x)))
+
+
+class TransformerBlock(nn.Module):
+ def __init__(self, layer_id: int, args: ModelArgs):
+ super().__init__()
+ self.n_heads = args.n_heads
+ self.dim = args.dim
+ self.head_dim = args.dim // args.n_heads
+ self.attention = Attention(args)
+ self.feed_forward = FeedForward(
+ dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of
+ )
+ self.layer_id = layer_id
+ self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
+ self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+
+ def _forward_ffn(self, h):
+ return h + self.feed_forward(self.ffn_norm(h))
+
+ def _forward_attention(self, x, start_pos, freqs_cis, mask, prompt):
+ return x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask, prompt)
+
+ def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], prompt=None):
+ h = self._forward_attention(x, start_pos, freqs_cis, mask, prompt)
+ out = self._forward_ffn(h)
+ return out
+
+
+class Mlp(nn.Module):
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
+ """
+
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.fc2(x)
+ return x
+
+
+class Transformer(nn.Module):
+ def __init__(self, params: ModelArgs):
+ super().__init__()
+ self.params = params
+ self.vocab_size = params.vocab_size
+ self.n_layers = params.n_layers
+ self.tok_embeddings = ParallelEmbedding(
+ params.vocab_size, params.dim, init_method=nn.init.normal_,
+ )
+
+ self.layers = torch.nn.ModuleList()
+ for layer_id in range(params.n_layers):
+ self.layers.append(TransformerBlock(layer_id, params))
+
+ self.norm = RMSNorm(params.dim, eps=params.norm_eps)
+ self.output = ColumnParallelLinear(
+ params.dim, params.vocab_size, bias=False, init_method=default_linear_init,
+ )
+
+ self.freqs_cis = precompute_freqs_cis(
+ self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
+ )
+
+ # load clip
+ self.clip, _, _ = open_clip.create_model_and_transforms(
+ 'ViT-L-14', pretrained='openai')
+ for param in self.clip.parameters():
+ param.requires_grad = False
+ param.data = param.data.half()
+ self.clip.transformer = None
+
+ self.image_words = 30
+ self.cache_image_words = 0 # for inference
+
+ clip_width = self.clip.visual.conv1.out_channels
+ # create modal shared modules
+ self.resample_layers = nn.ModuleDict()
+ self.num_experts = 3
+ self.num_resample_layers = 8
+ for expert in range(self.num_experts):
+ expert = str(expert)
+ self.resample_layers[expert] = nn.ModuleList()
+ resampler_params = copy.deepcopy(params)
+ resampler_params.n_heads = 16
+ resampler_params.dim = clip_width
+ for layer_id in range(self.num_resample_layers):
+ self.resample_layers[expert].append(
+ TransformerBlock(layer_id, resampler_params))
+
+ self.conv1 = nn.ModuleDict()
+ self.positional_embedding = nn.ParameterDict()
+ self.resample_tokens = nn.ParameterDict()
+ self.clip_proj1 = nn.ModuleDict()
+ self.clip_proj2 = nn.ModuleDict()
+ self.routers = nn.ModuleDict()
+ self.start_tag = nn.ParameterDict()
+ self.end_tag = nn.ParameterDict()
+ # self.modals = ['image', 'audio', 'point', 'video', 'rgbd', 'rgbn', 'fmri', 'imu']
+ self.modals = ['image', 'audio', 'video', 'rgbd', 'rgbn', 'fmri', 'imu']
+ for modal in self.modals:
+ if modal in ['image', 'video', 'rgbd', 'rgbn']:
+ modal_tokens = 256 + 1
+ pass
+ elif modal == 'audio':
+ self.conv1[modal] = nn.Conv2d(
+ 1, clip_width, kernel_size=(16, 16), stride=(10, 10))
+ modal_tokens = 1212 + 1
+ self.positional_embedding[modal] = nn.Parameter(
+ torch.empty([modal_tokens, clip_width]))
+ nn.init.normal_(self.positional_embedding[modal], std=0.02)
+ elif modal == 'point':
+ from lib.point_utils import PointPatchEmbed
+ self.conv1[modal] = PointPatchEmbed(
+ in_channels=6, channels=clip_width)
+ modal_tokens = 1024 + 1
+ self.positional_embedding[modal] = nn.Parameter(
+ torch.empty([modal_tokens, clip_width]))
+ nn.init.normal_(self.positional_embedding[modal], std=0.02)
+ elif modal == 'fmri':
+ self.conv1[modal] = nn.Linear(15724, 8192)
+ self.positional_embedding[modal] = nn.Parameter(
+ torch.empty([8+1, clip_width]))
+ nn.init.normal_(self.positional_embedding[modal], std=0.02)
+ elif modal == 'imu':
+ self.conv1[modal] = nn.Conv1d(
+ in_channels=6, out_channels=clip_width, kernel_size=10, bias=False)
+ self.positional_embedding[modal] = nn.Parameter(
+ torch.empty([391+1, clip_width]))
+ nn.init.normal_(self.positional_embedding[modal], std=0.02)
+
+ self.routers[modal] = Mlp(
+ clip_width, clip_width * 4, self.num_experts)
+
+ self.resample_tokens[modal] = nn.Parameter(
+ torch.empty([1, 30, resampler_params.dim]))
+ nn.init.normal_(self.resample_tokens[modal], std=0.02)
+
+ self.clip_proj1[modal] = nn.Sequential(
+ nn.Linear(clip_width, resampler_params.dim),
+ nn.LayerNorm(resampler_params.dim))
+
+ self.clip_proj2[modal] = nn.Sequential(
+ nn.Linear(resampler_params.dim, params.dim),
+ nn.LayerNorm(params.dim))
+
+ self.start_tag[modal] = nn.Parameter(torch.rand(1, 1, params.dim))
+ self.end_tag[modal] = nn.Parameter(torch.rand(1, 1, params.dim))
+
+ # @torch.no_grad()
+
+ def clip_encode_image(self, x, modal='image'):
+ # shape = [*, width, grid ** 2]
+ x = x.reshape(x.shape[0], x.shape[1], -1)
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
+
+ x = torch.cat([self.clip.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1,
+ x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
+
+ # image-like modalities reuse CLIP's pretrained positional embedding; the rest use their own learned one
+ pos_embedding = self.clip.visual.positional_embedding
+ if modal in ['audio', 'point', 'fmri', 'imu']:
+ pos_embedding = self.positional_embedding[modal]
+
+ x = x + pos_embedding.to(x.dtype)
+ x = self.clip.visual.ln_pre(x)
+
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.clip.visual.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+ # preserve all spatial tokens
+ x = self.clip.visual.ln_post(x[:, :, :])
+
+ # if self.clip.visual.proj is not None:
+ # x = x @ self.clip.visual.proj
+
+ return x
+
+ def encode_image(self, x, modal='image'):
+ bsz = x.size(0)
+ T = 1
+ if modal in ['image']:
+ # modified from CLIP
+ x = self.clip.visual.conv1(x) # shape = [*, width, grid, grid]
+ elif modal in ['audio', 'imu']:
+ x = self.conv1[modal](x)
+ elif modal == 'point':
+ # [B, 16384, 6] -> [B, 1024, 1024, 1]
+ x = self.conv1[modal](x.float()).to(x.dtype)
+ elif modal in ['video', 'rgbd', 'rgbn']:
+ # [B, 15, 3, 224, 224]
+ B, T = x.shape[:2]
+ bsz = B * T
+ x = x.reshape(bsz, *x.shape[2:])
+ x = self.clip.visual.conv1(x)
+ elif modal == 'fmri':
+ x = self.conv1[modal](x)
+ # [B, 1, 8192] -> [B, 1024, 8]
+ x = x.reshape(x.size(0), self.clip.visual.conv1.out_channels, -1)
+
+ image_feats = self.clip_encode_image(x, modal=modal)
+ # take mean on time dimension
+ # all inputs are reduced to [B, L, D]
+ bsz = int(bsz / T)
+ image_feats = image_feats.reshape(
+ bsz, T, *image_feats.shape[1:]).mean(dim=1)
+
+ image_feats = self.clip_proj1[modal](image_feats)
+ image_feats = torch.cat(
+ [self.resample_tokens[modal].repeat(bsz, 1, 1), image_feats], dim=1)
+
+ # route each modality through the shared experts
+ # [B, L, D] -> [B, L, N]
+ routing_weights = self.routers[modal](image_feats).sigmoid()
+ routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
+
+ image_feats_experts = []
+ for expert_id in range(self.num_experts):
+ image_feats_expert = image_feats
+ for layer in self.resample_layers[str(expert_id)]:
+ image_feats_expert = layer(image_feats_expert, 0, None, None)
+
+ image_feats_expert = image_feats_expert[:, :self.resample_tokens[modal].size(1)]
+ routing_weight = routing_weights[:, :self.resample_tokens[modal].size(
+ 1), expert_id]
+ # [B, L, D] * [B, L, 1]
+ image_feats_expert = image_feats_expert * routing_weight[:, :, None]
+
+ image_feats_experts.append(image_feats_expert)
+
+ image_feats = sum(image_feats_experts)
+ image_feats = self.clip_proj2[modal](image_feats)
+
+ return image_feats
+
+ def forward(self, examples, image=None, modal='image'):
+ self._destroy_kv_cache() # training always disables kv cache
+ modal = modal[0]
+ _bsz, seqlen = examples.shape
+ h = self.tok_embeddings(examples)
+ self.freqs_cis = self.freqs_cis.to(h.device)
+
+ start_pos = 0
+ prefix_len = 0
+ if image is not None:
+ h_bos, h_caption = h[:, :1], h[:, 1:]
+ image_tokens = self.encode_image(image, modal)
+ h = torch.cat((h_bos, self.start_tag[modal].expand(
+ _bsz, -1, -1), image_tokens, self.end_tag[modal].expand(_bsz, -1, -1), h_caption), dim=1)
+ # sequence: bos, start_tag[modal], image tokens, end_tag[modal], caption; the tags delimit the visual prefix
+ prefix_len = image_tokens.shape[1] + 1 + 1
+ seqlen = h.shape[1]
+
+ freqs_cis = self.freqs_cis[start_pos:start_pos + seqlen]
+ mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=h.device)
+ mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
+ for layer in self.layers:
+ h = layer(h, start_pos, freqs_cis, mask)
+ h = self.norm(h)
+ output = self.output(h[:, prefix_len:, :])
+ return output
+
+ @torch.inference_mode()
+ def forward_inference(self, tokens: torch.Tensor, start_pos: int, image=None, modal='image'):
+ modal = modal[0] if isinstance(modal, list) else modal
+ _bsz, seqlen = tokens.shape
+ if start_pos == 0:
+ # kv cache will not re-allocate if size is unchanged
+ self._allocate_kv_cache(_bsz)
+ h = self.tok_embeddings(tokens)
+ self.freqs_cis = self.freqs_cis.to(h.device)
+
+ if image is not None:
+ h_bos, h_caption = h[:, :1], h[:, 1:]
+ image_tokens = self.encode_image(image, modal)
+ self.cache_image_words = image_tokens.shape[1]
+ h = torch.cat((h_bos, self.start_tag[modal].repeat(_bsz, 1, 1), image_tokens, self.end_tag[modal].repeat(_bsz, 1, 1), h_caption), dim=1)
+ seqlen = h.shape[1]
+ freqs_cis = self.freqs_cis[0: seqlen]
+ else:
+ if start_pos == 0:
+ self.cache_image_words = 0
+ freqs_cis = self.freqs_cis[0: seqlen]
+ else:
+ # if image was not None when start_pos=0,
+ # the offset should be added to start_pos within later forward_inference calls
+ start_pos = start_pos + self.cache_image_words
+ freqs_cis = self.freqs_cis[start_pos: start_pos + seqlen]
+
+ # freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
+
+ mask = None
+ if seqlen > 1:
+ mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
+ mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
+
+ for layer in self.layers:
+ h = layer(h, start_pos, freqs_cis, mask)
+ h = self.norm(h)
+ output = self.output(h[:, -1, :]) # only compute last logits
+ return output.float()
+
+ def _allocate_kv_cache(self, max_batch_size: int) -> None:
+ for layer in self.layers:
+ layer.attention.allocate_kv_cache(
+ max_batch_size, self.params.max_seq_len)
+
+ def _destroy_kv_cache(self) -> None:
+ for layer in self.layers:
+ layer.attention.destroy_kv_cache()
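`encode_image` above routes every modality through the shared resampler experts: a per-token router produces sigmoid weights over the experts, the weights are normalised to sum to one, and the expert outputs are summed with those weights. The standalone sketch below reproduces just that routing step with illustrative sizes; `nn.TransformerEncoderLayer` stands in for the `TransformerBlock` resampler stacks and is not the module used in the model.

```python
import torch
import torch.nn as nn

num_experts, dim, tokens = 3, 64, 30
# the router mirrors the Mlp above: Linear -> GELU -> Linear over the token features
router = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, num_experts))
experts = nn.ModuleList([nn.TransformerEncoderLayer(dim, nhead=4, batch_first=True)
                         for _ in range(num_experts)])   # stand-ins for the resample_layers stacks

feats = torch.randn(2, tokens, dim)
w = router(feats).sigmoid()
w = w / w.sum(dim=-1, keepdim=True)                       # normalised routing weights, (B, L, N)
out = sum(experts[i](feats) * w[..., i:i + 1] for i in range(num_experts))
print(out.shape)                                          # torch.Size([2, 30, 64])
```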
diff --git a/model/__init__.py b/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/__pycache__/__init__.cpython-310.pyc b/model/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ab67f64cfe739a7a1c51327e5e7a0ea2afc50cd9
Binary files /dev/null and b/model/__pycache__/__init__.cpython-310.pyc differ
diff --git a/model/__pycache__/__init__.cpython-39.pyc b/model/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bea7e8cd12224f18eb3eefbc92abf61852979fab
Binary files /dev/null and b/model/__pycache__/__init__.cpython-39.pyc differ
diff --git a/model/__pycache__/components.cpython-39.pyc b/model/__pycache__/components.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dfbf25224cf34dab4fa2f85fff462d2dbef6b4d6
Binary files /dev/null and b/model/__pycache__/components.cpython-39.pyc differ
diff --git a/model/__pycache__/meta.cpython-310.pyc b/model/__pycache__/meta.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3fbd547ed81e5dc062ca75c125fe8c8a668b5ead
Binary files /dev/null and b/model/__pycache__/meta.cpython-310.pyc differ
diff --git a/model/__pycache__/meta.cpython-39.pyc b/model/__pycache__/meta.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b69c01b1098d55637fb39f5be3aed62ddaf7cf43
Binary files /dev/null and b/model/__pycache__/meta.cpython-39.pyc differ
diff --git a/model/__pycache__/tokenizer.cpython-310.pyc b/model/__pycache__/tokenizer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9a4452629f6f6edcb5522834a8e5bbdfc825b48e
Binary files /dev/null and b/model/__pycache__/tokenizer.cpython-310.pyc differ
diff --git a/model/__pycache__/tokenizer.cpython-39.pyc b/model/__pycache__/tokenizer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c1a8d58048c6364146decfe9c883d63dc197e359
Binary files /dev/null and b/model/__pycache__/tokenizer.cpython-39.pyc differ
diff --git a/model/components.py b/model/components.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c8bc4e88484950988aaad4faf6d34ec1a4ec8bf
--- /dev/null
+++ b/model/components.py
@@ -0,0 +1,57 @@
+import warnings
+import torch
+import torch.nn as nn
+
+try:
+ from apex.normalization import FusedRMSNorm as RMSNorm
+except ImportError:
+ warnings.warn("Cannot import apex RMSNorm, falling back to the vanilla implementation")
+
+ class RMSNorm(torch.nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-6):
+ """
+ Initialize the RMSNorm normalization layer.
+
+ Args:
+ dim (int): The dimension of the input tensor.
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+ Attributes:
+ eps (float): A small value added to the denominator for numerical stability.
+ weight (nn.Parameter): Learnable scaling parameter.
+
+ """
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def _norm(self, x):
+ """
+ Apply the RMSNorm normalization to the input tensor.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ torch.Tensor: The normalized tensor.
+
+ """
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x):
+ """
+ Forward pass through the RMSNorm layer.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ torch.Tensor: The output tensor after applying RMSNorm.
+
+ """
+ output = self._norm(x.float()).type_as(x)
+ return output * self.weight
+
+
+
+
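A quick usage sketch for `RMSNorm`, assuming the vanilla fallback above (the apex fused kernel should agree up to numerical tolerance):

```python
import torch
from model.components import RMSNorm  # apex FusedRMSNorm if available, otherwise the fallback above

norm = RMSNorm(512, eps=1e-5)
x = torch.randn(2, 16, 512)
y = norm(x)
# each feature vector is rescaled by the reciprocal of its root mean square, then scaled by `weight`
expected = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5) * norm.weight
print(torch.allclose(y, expected, atol=1e-5))  # True for the vanilla fallback
```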
diff --git a/model/meta.py b/model/meta.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0ab6daaa14337633f9d3261d78248683d04c930
--- /dev/null
+++ b/model/meta.py
@@ -0,0 +1,175 @@
+from typing import List
+import torch
+import torch.nn as nn
+import json
+import os
+from .tokenizer import Tokenizer
+from . import LLM
+
+from fairscale.nn.model_parallel import initialize as fs_init
+
+
+class MetaModel(nn.Module):
+
+ def __init__(self, llama_type, llama_config, llama_ckpt_dir=None, tokenizer_path=None):
+ super().__init__()
+
+ self.criterion = torch.nn.CrossEntropyLoss(ignore_index=0)
+
+ ModelArgs = LLM.__dict__[llama_type].ModelArgs
+ Transformer = LLM.__dict__[llama_type].Transformer
+
+ with open(llama_config, "r") as f:
+ params = json.loads(f.read())
+ model_args: ModelArgs = ModelArgs(
+ max_seq_len=2048, max_batch_size=32, **params
+ )
+ self.tokenizer = Tokenizer(model_path=tokenizer_path)
+ model_args.vocab_size = self.tokenizer.n_words
+
+ model = Transformer(model_args)
+ mp_rank = fs_init.get_model_parallel_rank()
+ if llama_ckpt_dir is not None:
+ ckpt_path = os.path.join(llama_ckpt_dir, f"consolidated.{mp_rank:02d}.pth")
+ if os.path.exists(ckpt_path):
+ checkpoint = torch.load(ckpt_path, map_location="cpu")
+ msg = model.load_state_dict(checkpoint, strict=False)
+ print(msg)
+ else:
+ print(f'Checkpoint not found at {ckpt_path}')
+ self.llma = model
+ for name, param in self.named_parameters():
+ if param.requires_grad:
+ print(f"Trainable param: {name}, {param.shape}, {param.dtype}")
+ count = sum(p.numel() for p in self.parameters() if p.requires_grad)
+ print(f"Parameter count : {count}")
+
+ def forward(self, examples, labels, image=None, modal='image'):
+ output = self.llma(examples, image=image, modal=modal)
+ output = output[:, :-1, :]
+ labels = labels[:, 1:]
+
+ if labels.sum() == 0:
+ c_loss = output.mean() * 0
+ else:
+ c_loss = self.criterion(output.reshape(-1, 32000), labels.flatten())
+
+ return c_loss
+
+ def generate(
+ self,
+ prompts: List[str],
+ images,
+ max_gen_len: int,
+ temperature: float = 0.8,
+ top_p: float = 0.95,
+ modal = ['image'],
+ ) -> List[str]:
+ bsz = len(prompts)
+ params = self.llma.params
+ assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
+
+ prompt_tokens = [self.tokenizer.encode(
+ x, bos=True, eos=False) for x in prompts]
+
+ min_prompt_size = min([len(t) for t in prompt_tokens])
+ max_prompt_size = max([len(t) for t in prompt_tokens])
+
+ total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
+
+ tokens = torch.full(
+ (bsz, total_len), self.tokenizer.pad_id).cuda().long()
+ for k, t in enumerate(prompt_tokens):
+ tokens[k, : len(t)] = torch.tensor(t).long()
+ input_text_mask = tokens != self.tokenizer.pad_id
+ start_pos = min_prompt_size
+ prev_pos = 0
+ for cur_pos in range(start_pos, total_len):
+ logits = self.llma.forward_inference(tokens[:, prev_pos:cur_pos], prev_pos, images if prev_pos == 0 else None, modal=modal)
+ if temperature > 0:
+ probs = torch.softmax(logits / temperature, dim=-1)
+ next_token = self.sample_top_p(probs, top_p)
+ else:
+ next_token = torch.argmax(logits, dim=-1)
+ next_token = next_token.reshape(-1)
+ # only replace token if prompt has already been generated
+ next_token = torch.where(
+ input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
+ )
+ tokens[:, cur_pos] = next_token
+ prev_pos = cur_pos
+
+ decoded = []
+ for i, t in enumerate(tokens.tolist()):
+ # cut to max gen len
+ t = t[: len(prompt_tokens[i]) + max_gen_len]
+ # cut to eos tok if any
+ try:
+ t = t[: t.index(self.tokenizer.eos_id)]
+ except ValueError:
+ pass
+ decoded.append(self.tokenizer.decode(t))
+ return decoded
+
+ @torch.inference_mode()
+ def stream_generate(
+ self,
+ prompt: str,
+ images,
+ max_gen_len: int,
+ temperature: float = 0.8,
+ top_p: float = 0.95,
+ modal = ['image'],
+ ):
+ params = self.llma.params
+
+ prompt_tokens = self.tokenizer.encode(prompt, bos=True, eos=False)
+ # truncate from the left. leave some space for generation.
+ max_seq_len = params.max_seq_len
+ if images is not None:
+ max_seq_len -= self.llma.image_words
+
+ max_prompt_size = max_seq_len - max_gen_len
+ prompt_tokens = prompt_tokens[-max_prompt_size:]
+
+ prompt_size = len(prompt_tokens)
+
+ total_len = min(max_seq_len, max_gen_len + prompt_size)
+
+ tokens = torch.full([total_len], 0).cuda().long()
+
+ tokens[:len(prompt_tokens)] = torch.tensor(prompt_tokens).long()
+ start_pos = prompt_size
+ prev_pos = 0
+ generate_until = start_pos
+ for cur_pos in range(start_pos, total_len):
+ logits = self.llma.forward_inference(tokens[None, prev_pos:cur_pos], prev_pos, images if prev_pos == 0 else None, modal = modal)
+ if temperature > 0:
+ probs = torch.softmax(logits / temperature, dim=-1)
+ next_token = self.sample_top_p(probs, top_p)
+ else:
+ next_token = torch.argmax(logits, dim=-1)
+ next_token = next_token.item()
+
+ if next_token == self.tokenizer.eos_id:
+ break
+
+ tokens[cur_pos] = next_token
+ prev_pos = cur_pos
+ generate_until = cur_pos + 1
+ yield {"text": self.tokenizer.decode(tokens[start_pos:generate_until].tolist()), "end_of_content": False}
+
+ yield {"text": self.tokenizer.decode(tokens[start_pos:generate_until].tolist()), "end_of_content": True}
+
+ def sample_top_p(self, probs, p):
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
+ mask = probs_sum - probs_sort > p
+ probs_sort[mask] = 0.0
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+ next_token = torch.multinomial(probs_sort, num_samples=1)
+ next_token = torch.gather(probs_idx, -1, next_token)
+ return next_token
+
+ def get_image_words(self):
+ return self.llma.image_words
\ No newline at end of file
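`MetaModel.sample_top_p` implements nucleus (top-p) sampling: sort the probabilities, drop tokens whose preceding cumulative mass already exceeds `p`, renormalise, and sample from what remains. A small numeric illustration with made-up probabilities:

```python
import torch

probs = torch.tensor([[0.50, 0.30, 0.15, 0.05]])
p = 0.80
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
mask = torch.cumsum(probs_sort, dim=-1) - probs_sort > p   # True only for the last token here
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))      # renormalise the surviving 0.50/0.30/0.15
next_token = torch.gather(probs_idx, -1, torch.multinomial(probs_sort, num_samples=1))
print(next_token)                                          # index into the original (unsorted) distribution
```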
diff --git a/model/tokenizer.py b/model/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4315856eea5c4318499c8909898252902252f30
--- /dev/null
+++ b/model/tokenizer.py
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the GNU General Public License version 3.
+
+from sentencepiece import SentencePieceProcessor
+from logging import getLogger
+from typing import List
+import os
+
+
+logger = getLogger()
+
+
+class Tokenizer:
+ def __init__(self, model_path: str):
+ # reload tokenizer
+ assert os.path.isfile(model_path), model_path
+ self.sp_model = SentencePieceProcessor(model_file=model_path)
+ logger.info(f"Reloaded SentencePiece model from {model_path}")
+
+ # BOS / EOS token IDs
+ self.n_words: int = self.sp_model.vocab_size()
+ self.bos_id: int = self.sp_model.bos_id()
+ self.eos_id: int = self.sp_model.eos_id()
+ self.pad_id: int = self.sp_model.pad_id()
+ logger.info(
+ f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
+ )
+ assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
+
+ def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
+ assert type(s) is str
+ t = self.sp_model.encode(s)
+ if bos:
+ t = [self.bos_id] + t
+ if eos:
+ t = t + [self.eos_id]
+ return t
+
+ def decode(self, t: List[int]) -> str:
+ return self.sp_model.decode(t)
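A short usage sketch for the SentencePiece wrapper above; the tokenizer path is an assumption about your local checkout, so adjust it to wherever the LLaMA-2 tokenizer model lives.

```python
from model.tokenizer import Tokenizer

tok = Tokenizer(model_path="config/llama2/tokenizer.model")   # assumed local path
ids = tok.encode("What is in the image?", bos=True, eos=False)
print(ids[0] == tok.bos_id)   # True: BOS is prepended when bos=True
print(tok.decode(ids))        # round-trips back to the original text
```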
diff --git a/requirements.txt b/requirements.txt
index ff5a872a773e1619013dc49c7be53ad722943b40..ce74fdd1f2242fc4e7bc50f084f7030081836fa5 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,13 @@
---extra-index-url https://download.pytorch.org/whl/cu113
-torch==1.12.0+cu113
+--extra-index-url https://download.pytorch.org/whl/cu117
+torch==2.0.0+cu117
+packaging
fairscale
sentencepiece
Pillow
huggingface_hub
-git+https://github.com/csuhan/timm_0_3_2.git
-git+https://github.com/openai/CLIP.git
\ No newline at end of file
+open_clip_torch
+pytorchvideo==0.1.5
+torchaudio
+matplotlib
+flash-attn
+gradio
\ No newline at end of file
diff --git a/util/__pycache__/misc.cpython-310.pyc b/util/__pycache__/misc.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4caa262729d9934200c8f44f3ca67d0913580474
Binary files /dev/null and b/util/__pycache__/misc.cpython-310.pyc differ
diff --git a/util/__pycache__/misc.cpython-39.pyc b/util/__pycache__/misc.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9f464da6e5db2871e7f85f496f2a0df542be9804
Binary files /dev/null and b/util/__pycache__/misc.cpython-39.pyc differ
diff --git a/util/lr_sched.py b/util/lr_sched.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc4624f4fb441ea7e37e50857813cb149887a0c0
--- /dev/null
+++ b/util/lr_sched.py
@@ -0,0 +1,42 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+def adjust_learning_rate(optimizer, it, args):
+ """Decay the learning rate with half-cycle cosine after warmup"""
+ if it < args.warmup_iters: # 1) linear warmup for warmup_iters steps
+ lr = args.lr * it / args.warmup_iters
+ elif it > args.lr_decay_iters: # 2) if it > lr_decay_iters, return min learning rate
+ lr = args.min_lr
+ else: # 3) in between, use cosine decay down to min learning rate
+ decay_ratio = (it - args.warmup_iters) / (args.lr_decay_iters - args.warmup_iters)
+ assert 0 <= decay_ratio <= 1
+ coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
+ lr = args.min_lr + (args.lr - args.min_lr) * coeff
+
+ for param_group in optimizer.param_groups:
+ if "lr_scale" in param_group:
+ param_group["lr"] = lr * param_group["lr_scale"]
+ else:
+ param_group["lr"] = lr
+ return lr
+
+
+def adjust_learning_rate_epoch(optimizer, epoch, args):
+ """Decay the learning rate with half-cycle cosine after warmup"""
+ if epoch < args.warmup_epochs:
+ lr = args.lr * epoch / args.warmup_epochs
+ else:
+ lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
+ (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
+ for param_group in optimizer.param_groups:
+ if "lr_scale" in param_group:
+ param_group["lr"] = lr * param_group["lr_scale"]
+ else:
+ param_group["lr"] = lr
+ return lr
+
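`adjust_learning_rate` ramps the learning rate linearly for `warmup_iters` steps and then follows a half-cycle cosine down to `min_lr`. The sketch below evaluates the schedule at a few iterations with illustrative hyperparameters, not the actual training configuration:

```python
import argparse
import torch
from util.lr_sched import adjust_learning_rate

# illustrative values only
args = argparse.Namespace(lr=2e-4, min_lr=2e-5, warmup_iters=100, lr_decay_iters=1000)
opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=args.lr)
for it in [0, 50, 100, 550, 1000, 2000]:
    print(it, f"{adjust_learning_rate(opt, it, args):.2e}")
# 0 -> 0.0, 50 -> 1e-4 (mid-warmup), 100 -> 2e-4 (peak), 550 -> 1.1e-4, 1000 -> 2e-5, then min_lr
```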
diff --git a/util/misc.py b/util/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..cea0d87e40afd9b8be34ef99da7b1409cb1e43ba
--- /dev/null
+++ b/util/misc.py
@@ -0,0 +1,516 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# DeiT: https://github.com/facebookresearch/deit
+# BEiT: https://github.com/microsoft/unilm/tree/master/beit
+# --------------------------------------------------------
+
+import builtins
+import datetime
+import os
+import glob
+import time
+from collections import defaultdict, deque
+from pathlib import Path
+import subprocess
+
+import torch
+import torch.distributed as dist
+from torch import inf
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+from torch.distributed.fsdp import (
+ FullyShardedDataParallel as FSDP,
+ StateDictType,
+ FullStateDictConfig,
+)
+from torch.distributed._shard.api import load_with_process_group
+
+from fairscale.nn.model_parallel import initialize as fs_init
+
+from types import TracebackType
+from typing import Any, Optional
+import torch
+import torch.nn as nn
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window_size=20, fmt=None):
+ if fmt is None:
+ fmt = "{median:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window_size)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not is_dist_avail_and_initialized():
+ return
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value)
+
+
+class MetricLogger(object):
+ def __init__(self, delimiter="\t"):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+
+ def update(self, **kwargs):
+ for k, v in kwargs.items():
+ if v is None:
+ continue
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError("'{}' object has no attribute '{}'".format(
+ type(self).__name__, attr))
+
+ def __str__(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ loss_str.append(
+ "{}: {}".format(name, str(meter))
+ )
+ return self.delimiter.join(loss_str)
+
+ def synchronize_between_processes(self):
+ for meter in self.meters.values():
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ self.meters[name] = meter
+
+ def log_every(self, iterable, print_freq, header=None, start_iter=0):
+ i = start_iter
+ if not header:
+ header = ''
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
+ data_time = SmoothedValue(fmt='{avg:.4f}')
+ log_msg = [
+ header,
+ '[{0' + '}/{1}]',
+ '{meters}',
+ 'time: {time}',
+ 'data: {data}'
+ ]
+ if torch.cuda.is_available():
+ log_msg.append('max mem: {memory:.0f}')
+ log_msg = self.delimiter.join(log_msg)
+ MB = 1024.0 * 1024.0
+ for obj in iterable:
+ data_time.update(time.time() - end)
+ yield obj
+ iter_time.update(time.time() - end)
+ if i % print_freq == 0:
+ try:
+ total_len = len(iterable)
+ except:
+ total_len = "unknown"
+ if torch.cuda.is_available():
+ print(log_msg.format(
+ i, total_len,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB))
+ else:
+ print(log_msg.format(
+ i, total_len,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time)))
+ i += 1
+ end = time.time()
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('{} Total time: {} ({:.4f} s / it)'.format(
+ header, total_time_str, total_time / len(iterable)))
+
+
+def setup_for_distributed(is_master):
+ """
+ This function disables printing when not in master process
+ """
+ builtin_print = builtins.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop('force', False)
+# force = force or (get_world_size() > 8)
+ if is_master or force:
+ now = datetime.datetime.now().time()
+ builtin_print('[{}] '.format(now), end='') # print with time stamp
+ builtin_print(*args, **kwargs)
+
+ builtins.print = print
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+ if is_main_process():
+ torch.save(*args, **kwargs)
+
+def init_distributed_mode(args):
+ if args.dist_on_itp:
+ args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+ args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+ args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
+ args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
+ os.environ['LOCAL_RANK'] = str(args.gpu)
+ os.environ['RANK'] = str(args.rank)
+ os.environ['WORLD_SIZE'] = str(args.world_size)
+ # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
+ elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ['WORLD_SIZE'])
+ args.gpu = int(os.environ['LOCAL_RANK'])
+ elif 'SLURM_PROCID' in os.environ:
+ os.environ['MASTER_PORT'] = '8994'
+ while 'MASTER_ADDR' not in os.environ or len(os.environ['MASTER_ADDR'].strip()) == 0:
+ os.environ['MASTER_ADDR'] = subprocess.check_output('sinfo -Nh -n %s | head -n 1 | awk \'{print $1}\'' % os.environ['SLURM_NODELIST'], shell=True, ).decode().strip()
+ time.sleep(1)
+ print(os.environ['MASTER_ADDR'])
+ args.world_size = int(os.environ['SLURM_NPROCS'])
+ args.rank = int(os.environ['SLURM_PROCID'])
+ args.gpu = args.rank % torch.cuda.device_count()
+ args.local_rank = args.gpu
+ os.environ['LOCAL_RANK'] = str(args.gpu)
+ os.environ['WORLD_SIZE'] = str(args.world_size)
+ os.environ['RANK'] = str(args.rank)
+ else:
+ print('Not using distributed mode')
+ setup_for_distributed(is_master=True) # hack
+ args.distributed = False
+ return
+
+ args.distributed = True
+
+ torch.cuda.set_device(args.gpu)
+ args.dist_backend = 'nccl'
+ print('| distributed init (rank {}): {}, gpu {}'.format(
+ args.rank, args.dist_url, args.gpu), flush=True)
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+ world_size=args.world_size, rank=args.rank)
+ torch.distributed.barrier()
+ setup_for_distributed(args.rank == 0)
+
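+# Launch note (illustrative): `torchrun --nproc_per_node=N train.py ...` populates
+# RANK / WORLD_SIZE / LOCAL_RANK, so the second branch above is taken; under SLURM
+# (`srun`), SLURM_PROCID / SLURM_NPROCS / SLURM_NODELIST drive the third branch instead.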
+
+def init_distributed_mode1(args):
+ if args.dist_on_itp:
+ args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+ args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+ args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
+ args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
+ os.environ['LOCAL_RANK'] = str(args.gpu)
+ os.environ['RANK'] = str(args.rank)
+ os.environ['WORLD_SIZE'] = str(args.world_size)
+ # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
+ elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ['WORLD_SIZE'])
+ args.gpu = int(os.environ['LOCAL_RANK'])
+ elif 'SLURM_PROCID' in os.environ:
+ args.rank = int(os.environ['SLURM_PROCID'])
+ args.gpu = args.rank % torch.cuda.device_count()
+ else:
+ print('Not using distributed mode')
+ setup_for_distributed(is_master=True) # hack
+ args.distributed = False
+ return
+
+ args.distributed = True
+
+ torch.cuda.set_device(args.gpu)
+ args.dist_backend = 'nccl'
+ print('| distributed init (rank {}): {}, gpu {}'.format(
+ args.rank, args.dist_url, args.gpu), flush=True)
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+ world_size=args.world_size, rank=args.rank)
+ torch.distributed.barrier()
+ setup_for_distributed(args.rank == 0)
+
+
+class NativeScalerWithGradNormCount:
+ state_dict_key = "amp_scaler"
+
+ def __init__(self, args):
+ self._scaler = ShardedGradScaler(enabled=args.precision in ["fp16"])
+
+ def __call__(self, loss, optimizer, model, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
+ if update_grad:
+ self._scaler.scale(loss).backward(create_graph=create_graph)
+ if clip_grad is not None:
+ assert parameters is not None
+ self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
+ # norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
+ norm = model.clip_grad_norm_(clip_grad)
+ else:
+                raise NotImplementedError("please set clip_grad to a very large value if you do not want to clip.")
+ self._scaler.step(optimizer)
+ self._scaler.update()
+ else:
+ with model.no_sync():
+ self._scaler.scale(loss).backward(create_graph=create_graph)
+ norm = None
+ return norm
+
+ def state_dict(self):
+ return self._scaler.state_dict()
+
+ def load_state_dict(self, state_dict):
+ self._scaler.load_state_dict(state_dict)
+
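+# Usage sketch (illustrative; `args`, `model`, `optimizer`, `loader` and `accum_iter` are
+# assumed to come from the surrounding training script, and `model` is expected to be
+# FSDP-wrapped so that `clip_grad_norm_` / `no_sync` exist):
+#
+#   loss_scaler = NativeScalerWithGradNormCount(args)
+#   for step, batch in enumerate(loader):
+#       loss = model(batch)
+#       update = (step + 1) % accum_iter == 0
+#       grad_norm = loss_scaler(loss, optimizer, model, clip_grad=args.clip_grad,
+#                               parameters=model.parameters(), update_grad=update)
+#       if update:
+#           optimizer.zero_grad()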
+
+def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ parameters = [p for p in parameters if p.grad is not None]
+ norm_type = float(norm_type)
+ if len(parameters) == 0:
+ return torch.tensor(0.)
+ device = parameters[0].grad.device
+ if norm_type == inf:
+ total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
+ else:
+ total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
+ return total_norm
+
+
+def save_model(output_dir, args, epoch, iteration, model, optimizer, loss_scaler, dataset_state):
+ save_dir = os.path.join(output_dir, f"epoch_{epoch}_iter_{iteration:09d}")
+ os.makedirs(save_dir, exist_ok=True)
+ with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
+ to_save = {
+ "model": model.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "iter": iteration,
+ "epoch": epoch,
+ "scaler": loss_scaler.state_dict(),
+ "args": args,
+ "dataset_state": dataset_state,
+ }
+ save_path = os.path.join(
+ save_dir,
+ f"checkpoint.{dist.get_rank():05d}-of-{dist.get_world_size():05d}.pth",
+ )
+ torch.save(to_save, save_path)
+
+ if args.save_consolidated:
+ mp_rank = fs_init.get_model_parallel_rank()
+ mp_world_size = fs_init.get_model_parallel_world_size()
+ consolidated_model_save_path = os.path.join(
+ save_dir,
+ f"consolidated.{mp_rank:02d}-of-{mp_world_size:02d}.pth",
+ )
+ with FSDP.state_dict_type(
+ model,
+ StateDictType.FULL_STATE_DICT,
+ FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
+ ):
+ save_dtype = {
+ "fp16": torch.float16,
+ "bf16": torch.bfloat16,
+ "tf32": torch.float32,
+ }[args.precision]
+ consolidated_model_state_dict = {
+ k: v.to(save_dtype) for k, v in model.state_dict().items()
+ }
+ if fs_init.get_data_parallel_rank() == 0:
+ torch.save(consolidated_model_state_dict, consolidated_model_save_path)
+
+ # remove previous ckpts
+ ckpts = glob.glob(os.path.join(output_dir, "iter_*")) + glob.glob(os.path.join(output_dir, "epoch_*"))
+    ckpts.sort(key=os.path.getmtime)  # plain lexicographic order would put e.g. epoch_10 before epoch_2
+    if len(ckpts) > 2 and not args.keep_all:
+ for ckpt in ckpts[:-2]:
+ print('del', ckpt)
+ os.system(f'rm {ckpt} -rf')
+
+def load_model(args, model, optimizer, loss_scaler):
+ start_iter = 0
+ start_epoch = 0
+ if args.auto_resume:
+ ckpt_dirs = glob.glob(os.path.join(args.output_dir, "iter_*")) + glob.glob(os.path.join(args.output_dir, "epoch_*"))
+        ckpt_dirs.sort(key=os.path.getmtime)  # resume from the most recently written checkpoint
+ if len(ckpt_dirs) > 0:
+ args.resume = ckpt_dirs[-1]
+ if args.resume:
+ print("Resume checkpoint %s" % args.resume)
+ local_checkpoint_path = os.path.join(
+ args.resume,
+ f"checkpoint.{dist.get_rank():05d}-of-{dist.get_world_size():05d}.pth",
+ )
+ with load_with_process_group(fs_init.get_data_parallel_group()):
+ checkpoint = torch.load(local_checkpoint_path, map_location='cpu')
+ with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
+ model.load_state_dict(checkpoint['model'])
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ loss_scaler.load_state_dict(checkpoint['scaler'])
+ start_iter = int(checkpoint['iter']) + 1
+ if 'epoch' in checkpoint:
+ start_epoch = int(checkpoint['epoch'])
+ return start_epoch, start_iter
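+
+# Resume sketch (illustrative): `start_epoch, start_iter = load_model(args, model, optimizer,
+# loss_scaler)` is meant to be called after the model has been FSDP-wrapped and the optimizer
+# built, mirroring the layout written by `save_model` above.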
+
+def all_reduce_mean(x):
+ world_size = get_world_size()
+ if world_size > 1:
+ if isinstance(x, torch.Tensor):
+ x_reduce = x.clone().cuda()
+ else:
+ x_reduce = torch.tensor(x).cuda()
+ dist.all_reduce(x_reduce)
+ x_reduce /= world_size
+ return x_reduce.item()
+ else:
+ return x
+
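+# Typical use (illustrative): average a per-rank scalar for logging, e.g.
+# `loss_avg = all_reduce_mean(loss.item())`.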
+
+def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
+ decay = []
+ no_decay = []
+ for name, param in model.named_parameters():
+ if not param.requires_grad:
+ continue # frozen weights
+ #if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
+ if name.endswith(".bias") or name.endswith("norm.weight"):
+ no_decay.append(param)
+ else:
+ decay.append(param)
+ return [
+ {'params': no_decay, 'weight_decay': 0.},
+ {'params': decay, 'weight_decay': weight_decay}]
+
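+# Usage sketch (illustrative; `model` and the hyper-parameters are placeholders):
+#
+#   param_groups = add_weight_decay(model, weight_decay=0.05)
+#   optimizer = torch.optim.AdamW(param_groups, lr=1e-4, betas=(0.9, 0.95))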
+
+
+class default_tensor_type:
+ _tensor_type_stack = [(torch.float, "cpu")]
+
+ def __init__(
+ self,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ # Only limited combinations are supported.
+ assert device is None or device in ["cpu", "cuda"]
+ assert dtype is None or dtype in [torch.float, torch.bfloat16, torch.half]
+ self.dtype, self.device = dtype, device
+
+ def __enter__(self) -> None:
+ dtype, device = self.dtype, self.device
+ if dtype is None:
+ dtype = default_tensor_type._tensor_type_stack[-1][0]
+ if device is None:
+ device = default_tensor_type._tensor_type_stack[-1][1]
+ default_tensor_type._tensor_type_stack.append((dtype, device))
+
+        # We use all three calls because the newer APIs (set_default_device, set_default_dtype)
+        # are sometimes ineffective on their own (e.g., set_default_device does not affect
+        # torch.Tensor(...) constructor calls).
+ torch.set_default_tensor_type(default_tensor_type.get_tensor_type(dtype, device))
+ torch.set_default_device(device)
+ torch.set_default_dtype(dtype)
+
+ def __exit__(
+ self,
+ exc_type: Optional[type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ default_tensor_type._tensor_type_stack.pop()
+ dtype, device = default_tensor_type._tensor_type_stack[-1]
+
+ torch.set_default_tensor_type(default_tensor_type.get_tensor_type(dtype, device))
+ torch.set_default_device(device)
+ torch.set_default_dtype(dtype)
+
+ @staticmethod
+ def get_tensor_type(dtype: torch.dtype, device: str) -> Any:
+ return {
+ (torch.float, "cpu"): torch.FloatTensor,
+ (torch.bfloat16, "cpu"): torch.BFloat16Tensor,
+ (torch.half, "cpu"): torch.HalfTensor,
+ (torch.float, "cuda"): torch.cuda.FloatTensor,
+ (torch.bfloat16, "cuda"): torch.cuda.BFloat16Tensor,
+ (torch.half, "cuda"): torch.cuda.HalfTensor,
+ }[(dtype, device)]
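+
+# Usage sketch (illustrative; `MyModel` is a placeholder): build a model with bf16
+# parameters directly on GPU, then restore the previous defaults on exit.
+#
+#   with default_tensor_type(dtype=torch.bfloat16, device="cuda"):
+#       model = MyModel()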
+
diff --git a/util/pos_embed.py b/util/pos_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..1924913c1ffe7c73b889a4d3bad586ee8b3d2d7d
--- /dev/null
+++ b/util/pos_embed.py
@@ -0,0 +1,113 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# Position embedding utils
+# --------------------------------------------------------
+
+import numpy as np
+import torch
+
+# --------------------------------------------------------
+# 2D sine-cosine position embedding
+# References:
+# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
+# MoCo v3: https://github.com/facebookresearch/moco-v3
+# --------------------------------------------------------
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size, grid_size])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token:
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
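+# Shape sketch (illustrative): for a ViT-style 14x14 patch grid with embed_dim=768,
+# get_2d_sincos_pos_embed(768, 14) returns a (196, 768) array, or (197, 768) with
+# cls_token=True (a zero row is prepended for the class token).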
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+    omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float was removed in NumPy >= 1.24
+ omega /= embed_dim / 2.
+ omega = 1. / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
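+# Equivalently (restating the code above): for position p and channel-pair index d in
+# [0, D/2), out[p, d] = p / 10000**(2d / D), emb[p, d] = sin(out[p, d]) and
+# emb[p, d + D/2] = cos(out[p, d]) -- the standard sinusoidal encoding up to channel order.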
+
+# --------------------------------------------------------
+# Interpolate position embeddings for high-resolution
+# References:
+# DeiT: https://github.com/facebookresearch/deit
+# --------------------------------------------------------
+def interpolate_pos_embed(model, checkpoint_model):
+ if 'pos_embed' in checkpoint_model:
+ pos_embed_checkpoint = checkpoint_model['pos_embed']
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = model.patch_embed.num_patches
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+ # height (== width) for the checkpoint position embedding
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int(num_patches ** 0.5)
+ # class_token and dist_token are kept unchanged
+ if orig_size != new_size:
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ checkpoint_model['pos_embed'] = new_pos_embed
+
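+# Usage sketch (illustrative; the checkpoint path and the "model" key are placeholders):
+#
+#   checkpoint = torch.load("checkpoint.pth", map_location="cpu")
+#   checkpoint_model = checkpoint.get("model", checkpoint)
+#   interpolate_pos_embed(model, checkpoint_model)
+#   model.load_state_dict(checkpoint_model, strict=False)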
+
+def interpolate_pos_embed_online(
+ pos_embed, orig_size, new_size, num_extra_tokens: int
+):
+    # pos_embed: (num_extra_tokens + H*W, embed_dim), e.g. (257, 1024) for a 16x16 grid plus one cls token
+ extra_tokens = pos_embed[:num_extra_tokens]
+ pos_tokens = pos_embed[num_extra_tokens:]
+ embedding_size = pos_tokens.shape[1]
+ pos_tokens = pos_tokens.reshape(
+ -1, orig_size[0], orig_size[1], embedding_size
+ ).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens, size=new_size, mode="bicubic", align_corners=False,
+ )
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, embedding_size)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=0)
+ return new_pos_embed
\ No newline at end of file