import os
import torch
import numpy as np
import pickle
from torch.utils.data import Dataset
from pointllm.utils import *
from pointllm.data.utils import *

class ModelNet(Dataset):
    def __init__(self, config_path, split, subset_nums=-1, use_color=False):
        """
        Args:
            data_args:
                split: train or test
        """
        super().__init__()

        if config_path is None:
            # * use the default config file in the same dir
            config_path = os.path.join(os.path.dirname(__file__), "modelnet_config", "ModelNet40.yaml")

        config = cfg_from_yaml_file(config_path)
        # * check data path
        self.root = config["DATA_PATH"]
    
        if not os.path.exists(self.root):
            raise FileNotFoundError(f"Data path {self.root} does not exist. Please check your data path.")

        self.npoints = config.npoints
        self.num_category = config.NUM_CATEGORY # * should be 40
        self.random_sample = config.random_sampling
        self.use_height = config.use_height
        self.use_normals = config.USE_NORMALS
        self.subset_nums = subset_nums
        self.normalize_pc = True
        self.use_color = use_color

        if self.use_height or self.use_normals:
            print(f"Warning: height and normals are usually not used for ModelNet, but use_height: {self.use_height} "
                  f"and use_normals: {self.use_normals}.")

        self.split = split
        assert self.split in ('train', 'test')

        self.catfile = os.path.join(os.path.dirname(__file__), "modelnet_config", 'modelnet40_shape_names_modified.txt')

        # "tv_stand" -> "tv stand"
        self.categories = [line.rstrip() for line in open(self.catfile)] # * list of category names

        # * expected preprocessed pickle, e.g. modelnet40_test_8192pts_fps.dat
        self.save_path = os.path.join(self.root,
                                    'modelnet%d_%s_%dpts_fps.dat' % (self.num_category, self.split, self.npoints))

        print('Loading processed data from %s...' % self.save_path)
        with open(self.save_path, 'rb') as f:
            self.list_of_points, self.list_of_labels = pickle.load(f) # * each point cloud: ndarray of (8192, 6), xyz + normals
        
        if self.subset_nums > 0:
            # * fix the random seed so the subset is reproducible
            import random
            random.seed(0)
            # * randomly choose subset_nums samples
            idxs = random.sample(range(len(self.list_of_labels)), self.subset_nums)
            self.list_of_labels = [self.list_of_labels[idx] for idx in idxs]
            self.list_of_points = [self.list_of_points[idx] for idx in idxs]

        print(f"Loaded {len(self.list_of_points)} samples from {self.save_path}.")

    def __len__(self):
        return len(self.list_of_labels)

    def _get_item(self, index):
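        """
        Subsample to self.npoints if needed, normalize xyz, then optionally
        drop normals, append a height channel, and pad zero colors.
        """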
        point_set, label = self.list_of_points[index], self.list_of_labels[index]

        if self.npoints < point_set.shape[0]:
            if self.random_sample:
                # * uniform random subsampling without replacement
                point_set = point_set[np.random.choice(point_set.shape[0], self.npoints, replace=False)]
            else:
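                # * farthest point sampling (from the wildcard-imported utils) keeps better geometric coverage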
                point_set = farthest_point_sample(point_set, self.npoints)

        point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
        if not self.use_normals:
            point_set = point_set[:, 0:3]

        if self.use_height:
            # * append each point's height above the lowest point along the gravity (y) axis
            self.gravity_dim = 1
            height_array = (point_set[:, self.gravity_dim:self.gravity_dim + 1]
                            - point_set[:, self.gravity_dim:self.gravity_dim + 1].min())
            point_set = np.concatenate((point_set, height_array), axis=1)
        
        if self.use_color:
            # * append zero padding (same width as the current channels) as placeholder colors
            point_set = np.concatenate((point_set, np.zeros_like(point_set)), axis=-1)

        return point_set, label.item() # * ndarray, int
    
    def pc_norm(self, pc):
        """ pc: NxC, return NxC """
        xyz = pc[:, :3]
        other_feature = pc[:, 3:]

        centroid = np.mean(xyz, axis=0)
        xyz = xyz - centroid
        m = np.max(np.sqrt(np.sum(xyz ** 2, axis=1)))
        xyz = xyz / m

        pc = np.concatenate((xyz, other_feature), axis=1)
        return pc

    def __getitem__(self, index):
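        """
        Shuffle point order at train time, re-normalize, and return a dict
        with the point cloud tensor, integer label, category name, and index.
        """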
        points, label = self._get_item(index)
        pt_idxs = np.arange(0, points.shape[0])  # * indices over all npoints points
        if self.split == 'train':
            np.random.shuffle(pt_idxs)
        current_points = points[pt_idxs].copy()

        if self.normalize_pc:
            # * ModelNet point clouds are already normalized; renormalize anyway for safety
            current_points = self.pc_norm(current_points)

        current_points = torch.from_numpy(current_points).float() # * N, C tensors
        label_name = self.categories[int(label)]

        data_dict = {
            "indice": index, # * int
            "point_clouds": current_points, # * tensor of N, C
            "labels": label, # * int
            "label_names": label_name # * str
        }

        return data_dict
    
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='ModelNet Dataset')

    parser.add_argument("--config_path", type=str, default=None, help="config file path.")
    parser.add_argument("--split", type=str, default="test", help="train or test.")
    parser.add_argument("--subset_nums", type=int, default=200)

    args = parser.parse_args()

    dataset = ModelNet(config_path=args.config_path, split=args.split, subset_nums=args.subset_nums)

    # * get the first item
    print(dataset[0])
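
    # A minimal sketch of batch loading, assuming torch's default collation can
    # stack the clouds (true here since every cloud is sampled to the same npoints):
    from torch.utils.data import DataLoader
    loader = DataLoader(dataset, batch_size=4, shuffle=False)
    batch = next(iter(loader))
    print(batch["point_clouds"].shape)  # (4, npoints, C)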