import os
import pickle
import random

import numpy as np
import torch
from torch.utils.data import Dataset

from pointllm.utils import *
from pointllm.data.utils import *
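# NOTE: cfg_from_yaml_file, farthest_point_sample, and pc_normalize (used below)
# are assumed to be provided by the wildcard imports from pointllm.utils and
# pointllm.data.utils.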


class ModelNet(Dataset):
    def __init__(self, config_path, split, subset_nums=-1, use_color=False):
        """
        Args:
            config_path: path to a YAML config file; if None, the bundled
                ModelNet40 config is used.
            split: "train" or "test".
            subset_nums: if > 0, randomly subsample this many examples.
            use_color: if True, pad each point with zero-valued color channels.
        """
        super(ModelNet, self).__init__()

        if config_path is None:
            # Fall back to the default config shipped alongside this module.
            config_path = os.path.join(os.path.dirname(__file__), "modelnet_config", "ModelNet40.yaml")

        config = cfg_from_yaml_file(config_path)
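        # The loaded config is expected to expose, via attribute access, the keys
        # read below: DATA_PATH, npoints, NUM_CATEGORY, random_sampling, use_height,
        # and USE_NORMALS. The exact schema lives in the YAML file above.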

        self.root = config.DATA_PATH

        if not os.path.exists(self.root):
            raise FileNotFoundError(f"Data path {self.root} does not exist. Please check your data path.")

        self.npoints = config.npoints
        self.num_category = config.NUM_CATEGORY
        self.random_sample = config.random_sampling
        self.use_height = config.use_height
        self.use_normals = config.USE_NORMALS
        self.subset_nums = subset_nums
        self.normalize_pc = True
        self.use_color = use_color

        if self.use_height or self.use_normals:
            print(f"Warning: height and normals are usually not used for ModelNet, "
                  f"but use_height: {self.use_height} and use_normals: {self.use_normals}.")

        self.split = split
        assert self.split in ('train', 'test')

        self.catfile = os.path.join(os.path.dirname(__file__), "modelnet_config",
                                    'modelnet40_shape_names_modified.txt')
        with open(self.catfile) as f:
            self.categories = [line.rstrip() for line in f]

        self.save_path = os.path.join(self.root,
                                      'modelnet%d_%s_%dpts_fps.dat' % (self.num_category, self.split, self.npoints))

        print('Loading processed data from %s...' % self.save_path)
        with open(self.save_path, 'rb') as f:
            self.list_of_points, self.list_of_labels = pickle.load(f)
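        # As the "_fps" suffix suggests, the cache is assumed to hold a pickled
        # (list_of_points, list_of_labels) tuple whose point clouds were
        # pre-sampled offline with farthest point sampling.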

        if self.subset_nums > 0:
            # Deterministically subsample the dataset, e.g. for quick evaluation runs.
            random.seed(0)
            idxs = random.sample(range(len(self.list_of_labels)), self.subset_nums)
            self.list_of_labels = [self.list_of_labels[idx] for idx in idxs]
            self.list_of_points = [self.list_of_points[idx] for idx in idxs]

        print(f"Loaded {len(self.list_of_points)} samples from {self.save_path}.")

    def __len__(self):
        return len(self.list_of_labels)

    def _get_item(self, index):
        point_set, label = self.list_of_points[index], self.list_of_labels[index]

        if self.npoints < point_set.shape[0]:
            if self.random_sample:
                # Uniform random subsampling without replacement.
                point_set = point_set[np.random.choice(point_set.shape[0], self.npoints, replace=False)]
            else:
                # Farthest point sampling for better spatial coverage.
                point_set = farthest_point_sample(point_set, self.npoints)

        point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
        if not self.use_normals:
            point_set = point_set[:, 0:3]

        if self.use_height:
            # Append each point's height above the lowest point along the gravity
            # axis (y) as an extra feature channel.
            self.gravity_dim = 1
            height_array = point_set[:, self.gravity_dim:self.gravity_dim + 1] \
                - point_set[:, self.gravity_dim:self.gravity_dim + 1].min()
            point_set = np.concatenate((point_set, height_array), axis=1)

        if self.use_color:
            # ModelNet carries no colors; pad with zeros so the channel layout
            # matches datasets that do.
            point_set = np.concatenate((point_set, np.zeros_like(point_set)), axis=-1)

        return point_set, label.item()
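    # The array returned by _get_item has shape (npoints, C): C is 3 for bare xyz,
    # 6 with normals, one more with the height channel, and doubled when use_color
    # pads zero color channels.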

    def pc_norm(self, pc):
        """pc: NxC, return NxC. Center xyz at its centroid and scale it into the unit sphere."""
        xyz = pc[:, :3]
        other_feature = pc[:, 3:]

        centroid = np.mean(xyz, axis=0)
        xyz = xyz - centroid
        m = np.max(np.sqrt(np.sum(xyz ** 2, axis=1)))
        xyz = xyz / m

        pc = np.concatenate((xyz, other_feature), axis=1)
        return pc
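    # Note: xyz is normalized twice, by pc_normalize in _get_item and by pc_norm in
    # __getitem__ (when normalize_pc is set). The second pass re-derives the same
    # centroid and scale, so it is effectively a no-op on already-normalized points.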

    def __getitem__(self, index):
        points, label = self._get_item(index)
        pt_idxs = np.arange(0, points.shape[0])
        if self.split == 'train':
            # Shuffle point order as a light augmentation during training.
            np.random.shuffle(pt_idxs)
        current_points = points[pt_idxs].copy()

        if self.normalize_pc:
            current_points = self.pc_norm(current_points)

        current_points = torch.from_numpy(current_points).float()
        label_name = self.categories[int(label)]

        data_dict = {
            "indice": index,  # key name kept as-is for downstream compatibility
            "point_clouds": current_points,
            "labels": label,
            "label_names": label_name
        }

        return data_dict
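
# A minimal DataLoader sketch (assumes the processed .dat cache exists; PyTorch's
# default collate keeps the string "label_names" field as a list per batch):
#   from torch.utils.data import DataLoader
#   loader = DataLoader(ModelNet(None, "test"), batch_size=4)
#   batch = next(iter(loader))  # batch["point_clouds"]: (4, npoints, C) tensor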

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='ModelNet Dataset')
    parser.add_argument("--config_path", type=str, default=None, help="config file path.")
    parser.add_argument("--split", type=str, default="test", help="train or test.")
    parser.add_argument("--subset_nums", type=int, default=200)

    args = parser.parse_args()

    dataset = ModelNet(config_path=args.config_path, split=args.split, subset_nums=args.subset_nums)
    print(dataset[0])
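
    # Example invocation (assuming this module is importable as pointllm.data.modelnet
    # and the processed ModelNet40 cache exists at the configured DATA_PATH):
    #   python -m pointllm.data.modelnet --split test --subset_nums 200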