"""
Preprocessing Script for ScanNet 20/200

Author: Xiaoyang Wu ([email protected])
Please cite our work if the code is helpful to you.
"""

import warnings

import torch

warnings.filterwarnings("ignore", category=DeprecationWarning)

import sys
import os
import argparse
import glob
import json
import plyfile
import numpy as np
import pandas as pd
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor
from itertools import repeat

# Load external constants
from meta_data.scannet200_constants import VALID_CLASS_IDS_200, VALID_CLASS_IDS_20

CLOUD_FILE_PFIX = '_vh_clean_2'
SEGMENTS_FILE_PFIX = '.0.010000.segs.json'
AGGREGATIONS_FILE_PFIX = '.aggregation.json'
CLASS_IDS200 = VALID_CLASS_IDS_200
CLASS_IDS20 = VALID_CLASS_IDS_20
IGNORE_INDEX = -1


def read_plymesh(filepath):
    """Read ply file and return it as numpy array. Returns None if emtpy."""
    with open(filepath, 'rb') as f:
        plydata = plyfile.PlyData.read(f)
    if plydata.elements:
        vertices = pd.DataFrame(plydata['vertex'].data).values
        faces = np.stack(plydata['face'].data['vertex_indices'], axis=0)
        return vertices, faces


# Map the raw category id to the point cloud
def point_indices_from_group(seg_indices, group, labels_pd):
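    """Return the point indices covered by a segment group, together with its remapped
    ScanNet20 and ScanNet200 label ids (IGNORE_INDEX for categories outside the valid sets)."""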
    group_segments = np.array(group['segments'])
    label = group['label']

    # Map the category name to id
    label_id20 = labels_pd[labels_pd['raw_category'] == label]['nyu40id']
    label_id20 = int(label_id20.iloc[0]) if len(label_id20) > 0 else 0
    label_id200 = labels_pd[labels_pd['raw_category'] == label]['id']
    label_id200 = int(label_id200.iloc[0]) if len(label_id200) > 0 else 0

    # Only store for the valid categories
    if label_id20 in CLASS_IDS20:
        label_id20 = CLASS_IDS20.index(label_id20)
    else:
        label_id20 = IGNORE_INDEX

    if label_id200 in CLASS_IDS200:
        label_id200 = CLASS_IDS200.index(label_id200)
    else:
        label_id200 = IGNORE_INDEX

    # Get the points whose segment index (per-point segment id) falls inside this group's segment list
    point_idx = np.where(np.isin(seg_indices, group_segments))[0]
    return point_idx, label_id20, label_id200


def face_normal(vertex, face):
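    """Compute per-face unit normals and triangle areas from the cross product of two edge vectors."""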
    v01 = vertex[face[:, 1]] - vertex[face[:, 0]]
    v02 = vertex[face[:, 2]] - vertex[face[:, 0]]
    vec = np.cross(v01, v02)
    length = np.sqrt(np.sum(vec ** 2, axis=1, keepdims=True)) + 1.0e-8
    nf = vec / length
    area = length * 0.5
    return nf, area


def vertex_normal(vertex, face):
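    """Compute per-vertex normals as the area-weighted sum of adjacent face normals, renormalized to unit length."""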
    nf, area = face_normal(vertex, face)
    nf = nf * area

    nv = np.zeros_like(vertex)
    for i in range(face.shape[0]):
        nv[face[i]] += nf[i]

    length = np.sqrt(np.sum(nv ** 2, axis=1, keepdims=True)) + 1.0e-8
    nv = nv / length
    return nv


def handle_process(scene_path, output_path, labels_pd, train_scenes, val_scenes, parse_normals=True):
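    """Process a single ScanNet scene: read the mesh, optionally compute vertex normals,
    attach ScanNet20/200 semantic and instance labels (train/val splits only), and save
    the result as a .pth dictionary in the matching split folder."""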
    scene_id = os.path.basename(scene_path)
    mesh_path = os.path.join(scene_path, f'{scene_id}{CLOUD_FILE_PFIX}.ply')
    segments_file = os.path.join(scene_path, f'{scene_id}{CLOUD_FILE_PFIX}{SEGMENTS_FILE_PFIX}')
    aggregations_file = os.path.join(scene_path, f'{scene_id}{AGGREGATIONS_FILE_PFIX}')
    info_file = os.path.join(scene_path, f'{scene_id}.txt')

    if scene_id in train_scenes:
        output_file = os.path.join(output_path, 'train', f'{scene_id}.pth')
        split_name = 'train'
    elif scene_id in val_scenes:
        output_file = os.path.join(output_path, 'val', f'{scene_id}.pth')
        split_name = 'val'
    else:
        output_file = os.path.join(output_path, 'test', f'{scene_id}.pth')
        split_name = 'test'

    print(f'Processing: {scene_id} in {split_name}')

    vertices, faces = read_plymesh(mesh_path)
    coords = vertices[:, :3]
    colors = vertices[:, 3:6]
    save_dict = dict(coord=coords, color=colors, scene_id=scene_id)

    # # Rotating the mesh to axis aligned
    # info_dict = {}
    # with open(info_file) as f:
    #     for line in f:
    #         (key, val) = line.split(" = ")
    #         info_dict[key] = np.fromstring(val, sep=' ')
    #
    # if 'axisAlignment' not in info_dict:
    #     rot_matrix = np.identity(4)
    # else:
    #     rot_matrix = info_dict['axisAlignment'].reshape(4, 4)
    # r_coords = coords.transpose()
    # r_coords = np.append(r_coords, np.ones((1, r_coords.shape[1])), axis=0)
    # r_coords = np.dot(rot_matrix, r_coords)
    # coords = r_coords

    # Parse Normals
    if parse_normals:
        save_dict["normal"] = vertex_normal(coords, faces)

    # Load segments file
    if split_name != "test":
        with open(segments_file) as f:
            segments = json.load(f)
            seg_indices = np.array(segments['segIndices'])

        # Load Aggregations file
        with open(aggregations_file) as f:
            aggregation = json.load(f)
            seg_groups = np.array(aggregation['segGroups'])

        # Generate new labels
        semantic_gt20 = np.ones((vertices.shape[0])) * IGNORE_INDEX
        semantic_gt200 = np.ones((vertices.shape[0])) * IGNORE_INDEX
        instance_ids = np.ones((vertices.shape[0])) * IGNORE_INDEX
        for group in seg_groups:
            point_idx, label_id20, label_id200 = \
                point_indices_from_group(seg_indices, group, labels_pd)

            semantic_gt20[point_idx] = label_id20
            semantic_gt200[point_idx] = label_id200
            instance_ids[point_idx] = group['id']

        semantic_gt20 = semantic_gt20.astype(int)
        semantic_gt200 = semantic_gt200.astype(int)
        instance_ids = instance_ids.astype(int)

        save_dict["semantic_gt20"] = semantic_gt20
        save_dict["semantic_gt200"] = semantic_gt200
        save_dict["instance_gt"] = instance_ids

        # Sanity check: the generated label arrays must not contain NaN or infinite values
        processed_vertices = np.hstack((semantic_gt200, instance_ids))

        if np.any(np.isnan(processed_vertices)) or not np.all(np.isfinite(processed_vertices)):
            raise ValueError(f'Found NaN in scene: {scene_id}')

    # Save processed data
    torch.save(save_dict, output_file)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_root', required=True, help='Path to the ScanNet dataset containing scene folders')
    parser.add_argument('--output_root', required=True, help='Output path where train/val folders will be located')
    # Note: argparse's type=bool treats any non-empty string (including "False") as True
    parser.add_argument('--parse_normals', default=True, type=bool, help='Whether to parse point normals')
    config = parser.parse_args()

    # Load label map
    labels_pd = pd.read_csv('scannet-preprocess/meta_data/scannetv2-labels.combined.tsv',
                            sep='\t', header=0)

    # Load train/val splits
    with open('scannet-preprocess/meta_data/scannetv2_train.txt') as train_file:
        train_scenes = train_file.read().splitlines()
    with open('scannet-preprocess/meta_data/scannetv2_val.txt') as val_file:
        val_scenes = val_file.read().splitlines()

    # Create output directories
    train_output_dir = os.path.join(config.output_root, 'train')
    os.makedirs(train_output_dir, exist_ok=True)
    val_output_dir = os.path.join(config.output_root, 'val')
    os.makedirs(val_output_dir, exist_ok=True)
    test_output_dir = os.path.join(config.output_root, 'test')
    os.makedirs(test_output_dir, exist_ok=True)

    # Load scene paths
    scene_paths = sorted(glob.glob(config.dataset_root + '/scans*/scene*'))

    # Preprocess data.
    print('Processing scenes...')
    pool = ProcessPoolExecutor(max_workers=mp.cpu_count())
    # pool = ProcessPoolExecutor(max_workers=1)
    _ = list(pool.map(handle_process, scene_paths, repeat(config.output_root), repeat(labels_pd), repeat(train_scenes),
                      repeat(val_scenes), repeat(config.parse_normals)))