"""
Preprocessing Script for ScanNet 20/200
Author: Xiaoyang Wu ([email protected])
Please cite our work if the code is helpful to you.
"""
import warnings

warnings.filterwarnings("ignore", category=DeprecationWarning)

import os
import sys
import argparse
import glob
import json
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor
from itertools import repeat

import numpy as np
import pandas as pd
import plyfile
import torch

# Load external constants
from meta_data.scannet200_constants import VALID_CLASS_IDS_200, VALID_CLASS_IDS_20
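
# ScanNet v2 file-name conventions: '_vh_clean_2.ply' is the cleaned, decimated mesh,
# '.0.010000.segs.json' its over-segmentation (one segment id per vertex), and
# '.aggregation.json' the instance annotations that group segments into labelled objects.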
CLOUD_FILE_PFIX = '_vh_clean_2'
SEGMENTS_FILE_PFIX = '.0.010000.segs.json'
AGGREGATIONS_FILE_PFIX = '.aggregation.json'
CLASS_IDS200 = VALID_CLASS_IDS_200
CLASS_IDS20 = VALID_CLASS_IDS_20
IGNORE_INDEX = -1

def read_plymesh(filepath):
    """Read a ply file and return (vertices, faces) as numpy arrays. Returns None if empty."""
    with open(filepath, 'rb') as f:
        plydata = plyfile.PlyData.read(f)
        if plydata.elements:
            vertices = pd.DataFrame(plydata['vertex'].data).values
            faces = np.stack(plydata['face'].data['vertex_indices'], axis=0)
            return vertices, faces

# Map the raw category id to the point cloud
def point_indices_from_group(seg_indices, group, labels_pd):
    group_segments = np.array(group['segments'])
    label = group['label']

    # Map the raw category name to the NYU40 id (ScanNet20) and the ScanNet200 id
    label_id20 = labels_pd[labels_pd['raw_category'] == label]['nyu40id']
    label_id20 = int(label_id20.iloc[0]) if len(label_id20) > 0 else 0
    label_id200 = labels_pd[labels_pd['raw_category'] == label]['id']
    label_id200 = int(label_id200.iloc[0]) if len(label_id200) > 0 else 0

    # Only keep valid categories; everything else is mapped to the ignore index
    if label_id20 in CLASS_IDS20:
        label_id20 = CLASS_IDS20.index(label_id20)
    else:
        label_id20 = IGNORE_INDEX
    if label_id200 in CLASS_IDS200:
        label_id200 = CLASS_IDS200.index(label_id200)
    else:
        label_id200 = IGNORE_INDEX

    # Get all points whose segment id appears in the group's segment list
    point_idx = np.where(np.isin(seg_indices, group_segments))[0]
    return point_idx, label_id20, label_id200
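
# Per-face normals: the cross product of two edge vectors is orthogonal to the face
# and its length equals twice the triangle area.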
def face_normal(vertex, face):
    v01 = vertex[face[:, 1]] - vertex[face[:, 0]]
    v02 = vertex[face[:, 2]] - vertex[face[:, 0]]
    vec = np.cross(v01, v02)
    length = np.sqrt(np.sum(vec ** 2, axis=1, keepdims=True)) + 1.0e-8
    nf = vec / length
    area = length * 0.5
    return nf, area
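
# Per-vertex normals: accumulate area-weighted face normals onto each face's vertices,
# then renormalize to unit length.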
def vertex_normal(vertex, face):
    nf, area = face_normal(vertex, face)
    nf = nf * area
    nv = np.zeros_like(vertex)
    for i in range(face.shape[0]):
        nv[face[i]] += nf[i]
    length = np.sqrt(np.sum(nv ** 2, axis=1, keepdims=True)) + 1.0e-8
    nv = nv / length
    return nv
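
# Process a single scene: read the mesh, transfer segment/aggregation annotations to
# per-point ScanNet20/ScanNet200 semantic labels and instance ids, and save everything
# as a dict in a .pth file under the scene's split folder.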
def handle_process(scene_path, output_path, labels_pd, train_scenes, val_scenes, parse_normals=True):
    scene_id = os.path.basename(scene_path)
    mesh_path = os.path.join(scene_path, f'{scene_id}{CLOUD_FILE_PFIX}.ply')
    segments_file = os.path.join(scene_path, f'{scene_id}{CLOUD_FILE_PFIX}{SEGMENTS_FILE_PFIX}')
    aggregations_file = os.path.join(scene_path, f'{scene_id}{AGGREGATIONS_FILE_PFIX}')
    info_file = os.path.join(scene_path, f'{scene_id}.txt')

    if scene_id in train_scenes:
        output_file = os.path.join(output_path, 'train', f'{scene_id}.pth')
        split_name = 'train'
    elif scene_id in val_scenes:
        output_file = os.path.join(output_path, 'val', f'{scene_id}.pth')
        split_name = 'val'
    else:
        output_file = os.path.join(output_path, 'test', f'{scene_id}.pth')
        split_name = 'test'

    print(f'Processing: {scene_id} in {split_name}')

    vertices, faces = read_plymesh(mesh_path)
    coords = vertices[:, :3]
    colors = vertices[:, 3:6]
    save_dict = dict(coord=coords, color=colors, scene_id=scene_id)

    # # Rotating the mesh to axis aligned
    # info_dict = {}
    # with open(info_file) as f:
    #     for line in f:
    #         (key, val) = line.split(" = ")
    #         info_dict[key] = np.fromstring(val, sep=' ')
    #
    # if 'axisAlignment' not in info_dict:
    #     rot_matrix = np.identity(4)
    # else:
    #     rot_matrix = info_dict['axisAlignment'].reshape(4, 4)
    # r_coords = coords.transpose()
    # r_coords = np.append(r_coords, np.ones((1, r_coords.shape[1])), axis=0)
    # r_coords = np.dot(rot_matrix, r_coords)
    # coords = r_coords

    # Parse normals from the mesh if requested
    if parse_normals:
        save_dict["normal"] = vertex_normal(coords, faces)

    # Test scenes ship without annotations, so only train/val get semantic and instance labels
    if split_name != "test":
        # Load segments file (over-segmentation: one segment id per vertex)
        with open(segments_file) as f:
            segments = json.load(f)
            seg_indices = np.array(segments['segIndices'])

        # Load aggregations file (instance groups of segments with raw category labels)
        with open(aggregations_file) as f:
            aggregation = json.load(f)
            seg_groups = np.array(aggregation['segGroups'])

        # Generate new labels
        semantic_gt20 = np.ones((vertices.shape[0])) * IGNORE_INDEX
        semantic_gt200 = np.ones((vertices.shape[0])) * IGNORE_INDEX
        instance_ids = np.ones((vertices.shape[0])) * IGNORE_INDEX
        for group in seg_groups:
            point_idx, label_id20, label_id200 = \
                point_indices_from_group(seg_indices, group, labels_pd)
            semantic_gt20[point_idx] = label_id20
            semantic_gt200[point_idx] = label_id200
            instance_ids[point_idx] = group['id']
        semantic_gt20 = semantic_gt20.astype(int)
        semantic_gt200 = semantic_gt200.astype(int)
        instance_ids = instance_ids.astype(int)

        save_dict["semantic_gt20"] = semantic_gt20
        save_dict["semantic_gt200"] = semantic_gt200
        save_dict["instance_gt"] = instance_ids

        # Sanity check: stacked labels must be finite before saving
        processed_vertices = np.hstack((semantic_gt200, instance_ids))
        if np.any(np.isnan(processed_vertices)) or not np.all(np.isfinite(processed_vertices)):
            raise ValueError(f'Found NaN in scene: {scene_id}')

    # Save processed data
    torch.save(save_dict, output_file)
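
# Entry point: parse arguments, load the label map and train/val splits, then preprocess
# all scenes in parallel.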
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_root', required=True, help='Path to the ScanNet dataset containing scene folders')
    parser.add_argument('--output_root', required=True, help='Output path where train/val/test folders will be located')
    # Note: argparse's type=bool treats any non-empty string as True; omit the flag to keep the default
    parser.add_argument('--parse_normals', default=True, type=bool, help='Whether to parse point normals')
    config = parser.parse_args()

    # Load label map
    labels_pd = pd.read_csv('scannet-preprocess/meta_data/scannetv2-labels.combined.tsv',
                            sep='\t', header=0)

    # Load train/val splits
    with open('scannet-preprocess/meta_data/scannetv2_train.txt') as train_file:
        train_scenes = train_file.read().splitlines()
    with open('scannet-preprocess/meta_data/scannetv2_val.txt') as val_file:
        val_scenes = val_file.read().splitlines()

    # Create output directories
    train_output_dir = os.path.join(config.output_root, 'train')
    os.makedirs(train_output_dir, exist_ok=True)
    val_output_dir = os.path.join(config.output_root, 'val')
    os.makedirs(val_output_dir, exist_ok=True)
    test_output_dir = os.path.join(config.output_root, 'test')
    os.makedirs(test_output_dir, exist_ok=True)

    # Load scene paths
    scene_paths = sorted(glob.glob(config.dataset_root + '/scans*/scene*'))

    # Preprocess data
    print('Processing scenes...')
    pool = ProcessPoolExecutor(max_workers=mp.cpu_count())
    # pool = ProcessPoolExecutor(max_workers=1)
    _ = list(pool.map(handle_process, scene_paths, repeat(config.output_root), repeat(labels_pd),
                      repeat(train_scenes), repeat(val_scenes), repeat(config.parse_normals)))
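
# Example invocation (script name and paths are illustrative, not part of the original):
#   python preprocess_scannet.py \
#       --dataset_root /path/to/scannet \
#       --output_root /path/to/scannet_processed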