import os
import json
import copy

import torch
import numpy as np
import transformers
from torch.utils.data import Dataset
from typing import Dict

from .utils import *


def make_object_point_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
    """Make the dataset and collator for the joint text and point cloud dataset."""
    data_collator = DataCollatorForPointTextDataset(tokenizer=tokenizer)
    if data_args.split_train_val:
        print("Loading training datasets.")
        train_dataset = ObjectPointCloudDataset(
            split='train',
            data_path=data_args.data_path,
            anno_path=data_args.anno_path,
            pointnum=data_args.pointnum,
            conversation_types=data_args.conversation_types,
            tokenizer=tokenizer,
            use_color=data_args.use_color,
            data_args=data_args
        )
        print("Done!")

        if data_args.data_debug_num > 0:
            print('Debug mode, using training set as val set.')
            val_dataset = train_dataset
        else:
            # * make a val dataset
            print("Loading validation datasets.")
            val_dataset = ObjectPointCloudDataset(
                split='val',  # * load the val split
                data_path=data_args.data_path,
                anno_path=data_args.anno_path,
                pointnum=data_args.pointnum,
                conversation_types=data_args.conversation_types,
                tokenizer=tokenizer,
                use_color=data_args.use_color,
                data_args=data_args
            )
        return dict(train_dataset=train_dataset, eval_dataset=val_dataset, data_collator=data_collator)
    else:
        # * use all data as training data
        train_dataset = ObjectPointCloudDataset(
            split='train',
            data_path=data_args.data_path,
            anno_path=data_args.anno_path,
            pointnum=data_args.pointnum,
            conversation_types=data_args.conversation_types,
            use_color=data_args.use_color,
            tokenizer=tokenizer,
            data_args=data_args
        )
        return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
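
# A minimal usage sketch: the keys of the returned dict match the keyword arguments of
# transformers.Trainer, so it can be unpacked directly. `model`, `training_args`, and
# `data_args` are assumed to be constructed elsewhere (e.g. in a training script).
#
#   data_module = make_object_point_data_module(tokenizer=tokenizer, data_args=data_args)
#   trainer = transformers.Trainer(model=model, args=training_args, **data_module)
#   trainer.train()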


class ObjectPointCloudDataset(Dataset):
    """Dataset of Objaverse object point clouds paired with text conversations."""
    def __init__(self,
                 data_path=None,
                 anno_path=None,
                 tokenizer=None,
                 pointnum=8192,
                 split='train',
                 conversation_types=None,  # * default is simple_description, used for stage-1 pre-training
                 use_color=True,
                 data_args=None):
        """
        split: only considered when data_args.split_train_val is True.
        conversation_types: tuple used to filter the data; default is ('simple_description',), other types are
            "detailed_description", "single_round", "multi_round".
        tokenizer: if None, only point clouds are loaded (no text is tokenized).
        """
        super(ObjectPointCloudDataset, self).__init__()

        self.data_path = data_path
        self.anno_path = anno_path
        self.tokenizer = tokenizer
        self.split = split
        if conversation_types is None:
            self.conversation_types = ("simple_description",)
        else:
            self.conversation_types = conversation_types

        self.data_args = data_args
        self.normalize_pc = True
        self.use_color = use_color

        self.pointnum = pointnum
        self.point_backbone_config = data_args.point_backbone_config if data_args is not None else None
        self.point_indicator = '<point>'

        # Load the data list from JSON
        print(f"Loading anno file from {anno_path}.")
        with open(anno_path, "r") as json_file:
            self.list_data_dict = json.load(json_file)

        # * print the conversation_types
        print(f"Using conversation_type: {self.conversation_types}")
        # * print before filtering
        print(f"Before filtering, the dataset size is: {len(self.list_data_dict)}.")

        # * iterate the list and filter
        # * these two ids have corrupted colored point files, so filter them out when use_color is True
        filter_ids = ['6760e543e1d645d5aaacd3803bcae524', 'b91c0711149d460a8004f9c06d3b7f38'] if self.use_color else []

        # Iterate the list and drop entries whose "conversation_type" is not in self.conversation_types
        self.list_data_dict = [
            data for data in self.list_data_dict
            if data.get('conversation_type', 'simple_description') in self.conversation_types
            and data.get('object_id') not in filter_ids
        ]

        # * print after filtering
        print(f"After filtering, the dataset size is: {len(self.list_data_dict)}.")
        # * print the size of each conversation_type
        for conversation_type in self.conversation_types:
            print(f"Number of {conversation_type}: {len([data for data in self.list_data_dict if data.get('conversation_type', 'simple_description') == conversation_type])}")

        if self.data_args is not None and self.data_args.data_debug_num > 0:
            self.list_data_dict = self.list_data_dict[:self.data_args.data_debug_num]
            # * print all the object_ids in debug mode
            print('Debug mode, using: ' + ' '.join([data['object_id'] for data in self.list_data_dict]))
        elif self.data_args is not None and self.data_args.split_train_val:
            # * split train and val with a 9:1 ratio
            if self.split == 'train':
                self.list_data_dict = self.list_data_dict[:int(self.data_args.split_ratio * len(self.list_data_dict))]
                print(f"Train set size: {len(self.list_data_dict)}")
            else:
                self.list_data_dict = self.list_data_dict[int(self.data_args.split_ratio * len(self.list_data_dict)):]
                print(f"Val set size: {len(self.list_data_dict)}")

    def _load_point_cloud(self, object_id, type='objaverse'):
        if type == 'objaverse':
            return self._load_objaverse_point_cloud(object_id)

    def _load_objaverse_point_cloud(self, object_id):
        filename = f"{object_id}_{self.pointnum}.npy"
        point_cloud = np.load(os.path.join(self.data_path, filename))

        if not self.use_color:
            # Keep only the xyz channels and drop the color channels.
            point_cloud = point_cloud[:, :3]

        return point_cloud

    def pc_norm(self, pc):
        """pc: NxC, return NxC. Center xyz at the centroid and scale it into the unit sphere; extra channels are kept as-is."""
        xyz = pc[:, :3]
        other_feature = pc[:, 3:]

        centroid = np.mean(xyz, axis=0)
        xyz = xyz - centroid
        m = np.max(np.sqrt(np.sum(xyz ** 2, axis=1)))
        xyz = xyz / m

        pc = np.concatenate((xyz, other_feature), axis=1)
        return pc
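
    # A small illustrative sketch of pc_norm (hypothetical values, not part of the pipeline):
    # the xyz channels become zero-centered with the farthest point on the unit sphere,
    # while any extra channels (e.g. rgb) pass through unchanged. `dataset` is assumed to
    # be an ObjectPointCloudDataset instance.
    #   pc = np.concatenate([np.random.rand(8192, 3) * 10.0, np.random.rand(8192, 3)], axis=1)
    #   normed = dataset.pc_norm(pc)
    #   assert np.isclose(np.max(np.linalg.norm(normed[:, :3], axis=1)), 1.0)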

    def __getitem__(self, index):
        sources = self.list_data_dict[index]
        if isinstance(index, int):
            sources = [sources]
        assert len(sources) == 1, "sources should be a list"

        if self.point_indicator in sources[0]['conversations'][0]['value']:

            object_id = self.list_data_dict[index]['object_id']

            # Point cloud representation
            point_cloud = self._load_point_cloud(object_id)  # * N, C
            if self.normalize_pc:
                point_cloud = self.pc_norm(point_cloud)  # * need to normalize since the point encoder expects normalized input

            if self.tokenizer is None:
                data_dict = dict(
                    point_clouds=torch.from_numpy(point_cloud.astype(np.float32)),
                    object_ids=object_id
                )
                return data_dict

            sources = preprocess_multimodal_point_cloud(
                copy.deepcopy([e["conversations"] for e in sources]), self.point_backbone_config, point_indicator=self.point_indicator)
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])

        data_dict = preprocess_v1(
            sources,
            self.tokenizer)

        if isinstance(index, int):
            data_dict = dict(input_ids=data_dict["input_ids"][0],
                             labels=data_dict["labels"][0])

        # a point cloud exists in the data
        if self.point_indicator in self.list_data_dict[index]['conversations'][0]['value']:
            data_dict['point_clouds'] = torch.from_numpy(point_cloud.astype(np.float32))

        return data_dict

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.list_data_dict)


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()

    parser.add_argument("--data_path", default="data/objaverse_data", type=str,
                        help="Path to the data directory.")
    parser.add_argument("--anno_path", default=None, type=str, required=True,
                        help="Path to the annotation file.")
    parser.add_argument("--split", default='train', type=str,
                        help="Whether to use the train or validation dataset.")
    parser.add_argument("--pointnum", default=8192, type=int,
                        help="Number of points in the point cloud.")
    parser.add_argument("--data_debug_num", default=0, type=int,
                        help="Number of data samples to debug with.")
    parser.add_argument("--split_train_val", default=False, action="store_true",
                        help="Whether to split the dataset into training and validation.")
    parser.add_argument("--split_ratio", default=0.9, type=float,
                        help="The ratio of training to validation data.")
    parser.add_argument("--tokenizer_path", default=None, type=str, required=True,
                        help="Path to the tokenizer config file.")

    args = parser.parse_args()

    # Initialize the tokenizer
    tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer_path)

    args.point_backbone_config = None

    # Initialize the dataset
    dataset = ObjectPointCloudDataset(
        data_path=args.data_path,
        anno_path=args.anno_path,
        pointnum=args.pointnum,
        split=args.split,
        tokenizer=tokenizer,
        data_args=args
    )

    # Example usage
    print(f'Dataset length: {len(dataset)}')
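
    # A further sketch (assuming the point-cloud .npy files referenced by the annotation
    # file exist under --data_path): fetch one sample and inspect the returned tensors.
    #   data = dataset[0]
    #   print(data["input_ids"].shape, data["labels"].shape)
    #   if "point_clouds" in data:
    #       print(data["point_clouds"].shape)  # (pointnum, C); C depends on the stored channels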