import os
import json
import torch
import numpy as np

import copy
import transformers
from typing import Dict
from torch.utils.data import Dataset

from .utils import *


def make_object_point_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
    """Make the dataset(s) and collator for training with text and point cloud data."""
    data_collator = DataCollatorForPointTextDataset(tokenizer=tokenizer)
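    # The returned keys (train_dataset, eval_dataset, data_collator) match the
    # keyword arguments expected by transformers.Trainer.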
    if data_args.split_train_val:
        print("Loading training datasets.")
        train_dataset = ObjectPointCloudDataset(
            split='train',
            data_path=data_args.data_path,
            anno_path=data_args.anno_path,
            pointnum=data_args.pointnum,
            conversation_types=data_args.conversation_types,
            tokenizer=tokenizer,
            use_color=data_args.use_color,
            data_args=data_args
        )
        print("Done!")
        if data_args.data_debug_num > 0:
            print('Debug mode, using training set as val set.')
            val_dataset = train_dataset
        else:
            print("Loading validation datasets.")
            val_dataset = ObjectPointCloudDataset(
                split='val',
                data_path=data_args.data_path,
                anno_path=data_args.anno_path,
                pointnum=data_args.pointnum,
                conversation_types=data_args.conversation_types,
                tokenizer=tokenizer,
                use_color=data_args.use_color,
                data_args=data_args
            )
        return dict(train_dataset=train_dataset, eval_dataset=val_dataset, data_collator=data_collator)
    else:
        train_dataset = ObjectPointCloudDataset(
            split='train',
            data_path=data_args.data_path,
            anno_path=data_args.anno_path,
            pointnum=data_args.pointnum,
            conversation_types=data_args.conversation_types,
            use_color=data_args.use_color,
            tokenizer=tokenizer,
            data_args=data_args
        )
        return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)


class ObjectPointCloudDataset(Dataset):
    """Dataset of Objaverse object point clouds paired with conversation text."""

    def __init__(self,
                 data_path=None,
                 anno_path=None,
                 tokenizer=None,
                 pointnum=8192,
                 split='train',
                 conversation_types=None,
                 use_color=True,
                 data_args=None):
        """
        split: only used when data_args.split_train_val is True.
        conversation_types: tuple used to filter the data; defaults to
            ("simple_description",). Other types are "detailed_description",
            "single_round", and "multi_round".
        tokenizer: if None, only point clouds are loaded (no text tokenization).
        """
        super(ObjectPointCloudDataset, self).__init__()

        self.data_path = data_path
        self.anno_path = anno_path
        self.tokenizer = tokenizer
        self.split = split
        if conversation_types is None:
            self.conversation_types = ("simple_description",)
        else:
            self.conversation_types = conversation_types

        self.data_args = data_args
        self.normalize_pc = True
        self.use_color = use_color

        self.pointnum = pointnum
        self.point_backbone_config = data_args.point_backbone_config if data_args is not None else None
        self.point_indicator = '<point>'
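        # '<point>' in the first conversation turn marks samples that come with a
        # point cloud; __getitem__ checks for it before loading and passes it to
        # preprocess_multimodal_point_cloud as the point_indicator.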

        print(f"Loading anno file from {anno_path}.")
        with open(anno_path, "r") as json_file:
            self.list_data_dict = json.load(json_file)

        print(f"Using conversation_type: {self.conversation_types}")
        print(f"Before filtering, the dataset size is: {len(self.list_data_dict)}.")
        filter_ids = ['6760e543e1d645d5aaacd3803bcae524', 'b91c0711149d460a8004f9c06d3b7f38'] if self.use_color else []

        self.list_data_dict = [
            data for data in self.list_data_dict
            if data.get('conversation_type', 'simple_description') in self.conversation_types
            and data.get('object_id') not in filter_ids
        ]

        print(f"After filtering, the dataset size is: {len(self.list_data_dict)}.")

        for conversation_type in self.conversation_types:
            print(f"Number of {conversation_type}: {len([data for data in self.list_data_dict if data.get('conversation_type', 'simple_description') == conversation_type])}")
        if self.data_args is not None and self.data_args.data_debug_num > 0:
            self.list_data_dict = self.list_data_dict[:self.data_args.data_debug_num]
            print('Debug mode, using: ' + ' '.join([data['object_id'] for data in self.list_data_dict]))
        elif self.data_args is not None and self.data_args.split_train_val:
            if self.split == 'train':
                self.list_data_dict = self.list_data_dict[:int(self.data_args.split_ratio * len(self.list_data_dict))]
                print(f"Train set size: {len(self.list_data_dict)}")
            else:
                self.list_data_dict = self.list_data_dict[int(self.data_args.split_ratio * len(self.list_data_dict)):]
                print(f"Val set size: {len(self.list_data_dict)}")

    def _load_point_cloud(self, object_id, type='objaverse'):
        if type == 'objaverse':
            return self._load_objaverse_point_cloud(object_id)

    def _load_objaverse_point_cloud(self, object_id):
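        # Point clouds are stored as "<object_id>_<pointnum>.npy"; the columns are
        # XYZ coordinates, optionally followed by color channels.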
        filename = f"{object_id}_{self.pointnum}.npy"
        point_cloud = np.load(os.path.join(self.data_path, filename))

        if not self.use_color:
            point_cloud = point_cloud[:, :3]

        return point_cloud

    def pc_norm(self, pc):
        """ pc: NxC, return NxC """
        xyz = pc[:, :3]
        other_feature = pc[:, 3:]
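
        # Center the points at their centroid, then scale so the farthest point
        # lies on the unit sphere; extra feature columns are left untouched.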
        centroid = np.mean(xyz, axis=0)
        xyz = xyz - centroid
        m = np.max(np.sqrt(np.sum(xyz ** 2, axis=1)))
        xyz = xyz / m

        pc = np.concatenate((xyz, other_feature), axis=1)
        return pc

    def __getitem__(self, index):
        sources = self.list_data_dict[index]
        if isinstance(index, int):
            sources = [sources]
        assert len(sources) == 1, "sources should contain exactly one sample"
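        # A point indicator in the first conversation turn means this sample has an
        # associated point cloud, which is loaded and normalized here.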
        if self.point_indicator in sources[0]['conversations'][0]['value']:
            object_id = self.list_data_dict[index]['object_id']

            point_cloud = self._load_point_cloud(object_id)
            if self.normalize_pc:
                point_cloud = self.pc_norm(point_cloud)
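
            # With no tokenizer, return only the point cloud tensor and the object ID.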
            if self.tokenizer is None:
                data_dict = dict(
                    point_clouds=torch.from_numpy(point_cloud.astype(np.float32)),
                    object_ids=object_id
                )
                return data_dict

            sources = preprocess_multimodal_point_cloud(
                copy.deepcopy([e["conversations"] for e in sources]), self.point_backbone_config, point_indicator=self.point_indicator)
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])

        data_dict = preprocess_v1(
            sources,
            self.tokenizer)

        if isinstance(index, int):
            data_dict = dict(input_ids=data_dict["input_ids"][0],
                             labels=data_dict["labels"][0])

        # Attach the point cloud only to samples that reference one in the conversation.
        if self.point_indicator in self.list_data_dict[index]['conversations'][0]['value']:
            data_dict['point_clouds'] = torch.from_numpy(point_cloud.astype(np.float32))

        return data_dict

    def __len__(self):
        """Return the number of samples."""
        return len(self.list_data_dict)


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()

    parser.add_argument("--data_path", default="data/objaverse_data", type=str,
                        help="Path to the data directory.")
    parser.add_argument("--anno_path", default=None, type=str, required=True,
                        help="Path to the annotation file.")
    parser.add_argument("--split", default='train', type=str,
                        help="Whether to use the train or validation dataset.")
    parser.add_argument("--pointnum", default=8192, type=int,
                        help="Number of points in the point cloud.")
    parser.add_argument("--data_debug_num", default=0, type=int,
                        help="Number of samples to use for debugging.")
    parser.add_argument("--split_train_val", default=False, action="store_true",
                        help="Whether to split the dataset into training and validation.")
    parser.add_argument("--split_ratio", default=0.9, type=float,
                        help="The ratio of training to validation data.")
    parser.add_argument("--tokenizer_path", default=None, type=str, required=True,
                        help="Path to the tokenizer config file.")

    args = parser.parse_args()

    tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer_path)

    args.point_backbone_config = None

    dataset = ObjectPointCloudDataset(
        data_path=args.data_path,
        anno_path=args.anno_path,
        pointnum=args.pointnum,
        split=args.split,
        tokenizer=tokenizer,
        data_args=args
    )

    print(f'Dataset length: {len(dataset)}')
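
    # Optional sanity check: with tokenizer=None, __getitem__ returns only the
    # normalized point cloud tensor and the object ID, which is handy for quickly
    # inspecting shapes. This assumes the first sample's opening conversation turn
    # contains the '<point>' indicator; otherwise tokenization would be attempted
    # without a tokenizer.
    pc_only_dataset = ObjectPointCloudDataset(
        data_path=args.data_path,
        anno_path=args.anno_path,
        pointnum=args.pointnum,
        split=args.split,
        tokenizer=None,
        data_args=args
    )
    if len(pc_only_dataset) > 0:
        sample = pc_only_dataset[0]
        print(f"First sample point cloud shape: {sample['point_clouds'].shape}")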