Jingkang Yang
first commit
bd27f44
raw
history blame
7.82 kB
"""
Preprocessing Script for ScanNet 20/200
Author: Xiaoyang Wu ([email protected])
Please cite our work if the code is helpful to you.
"""
import warnings
import torch
warnings.filterwarnings("ignore", category=DeprecationWarning)
import sys
import os
import argparse
import glob
import json
import plyfile
import numpy as np
import pandas as pd
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor
from itertools import repeat
# Load external constants
from meta_data.scannet200_constants import VALID_CLASS_IDS_200, VALID_CLASS_IDS_20
# File-name suffixes appended to a scene id (the scene folder's basename) to
# locate its inputs: `<scene_id>_vh_clean_2.ply` is the mesh and
# `<scene_id>_vh_clean_2.0.010000.segs.json` is the per-vertex segmentation.
CLOUD_FILE_PFIX = '_vh_clean_2'
SEGMENTS_FILE_PFIX = '.0.010000.segs.json'
# Suffix of the per-scene aggregation file `<scene_id>.aggregation.json`,
# whose `segGroups` entries provide the instance id and category label.
AGGREGATIONS_FILE_PFIX = '.aggregation.json'
# Ids kept for each benchmark. CLASS_IDS200 is matched against the label
# map's `id` column and CLASS_IDS20 against its `nyu40id` column. A kept id
# is replaced by its position in the respective sequence.
CLASS_IDS200 = VALID_CLASS_IDS_200
CLASS_IDS20 = VALID_CLASS_IDS_20
# Label stored for points with no valid class and/or no instance.
IGNORE_INDEX = -1
def read_plymesh(filepath):
    """Load a PLY mesh from *filepath*.

    Returns:
        A ``(vertices, faces)`` tuple. ``vertices`` is a 2-D array with one
        row per vertex and one column per vertex property. ``faces`` stacks
        every face's ``vertex_indices`` row-wise.
        ``None`` is returned (implicitly) when the file has no elements at all.
    """
    with open(filepath, 'rb') as ply_file:
        ply = plyfile.PlyData.read(ply_file)
    if not ply.elements:
        return None
    vertex_table = pd.DataFrame(ply['vertex'].data)
    face_indices = ply['face'].data['vertex_indices']
    return vertex_table.values, np.stack(face_indices, axis=0)
# Map the raw category id to the point cloud
def point_indices_from_group(seg_indices, group, labels_pd):
    """Resolve one aggregation group into point indices and benchmark labels.

    Args:
        seg_indices: per-point segment ids.
        group: one ``segGroups`` entry; its ``'segments'`` and ``'label'``
            keys are read.
        labels_pd: label-map table with ``raw_category``, ``nyu40id`` and
            ``id`` columns.

    Returns:
        ``(point_idx, label_id20, label_id200)``:
        - ``point_idx``: indices of the points whose segment id is listed in
          the group.
        - ``label_id20`` / ``label_id200``: the group's category as a position
          in ``CLASS_IDS20`` / ``CLASS_IDS200``, or ``IGNORE_INDEX`` when the
          resolved id is not in that sequence.
    """
    member_segments = np.array(group['segments'])
    rows = labels_pd[labels_pd['raw_category'] == group['label']]

    def benchmark_index(column, valid_ids):
        # First matching row's id, or 0 when the category is missing from the
        # label map; then its position in `valid_ids` (if present).
        candidates = rows[column]
        raw_id = 0 if len(candidates) == 0 else int(candidates.iloc[0])
        return valid_ids.index(raw_id) if raw_id in valid_ids else IGNORE_INDEX

    label_id20 = benchmark_index('nyu40id', CLASS_IDS20)
    label_id200 = benchmark_index('id', CLASS_IDS200)
    point_idx = np.nonzero(np.isin(seg_indices, member_segments))[0]
    return point_idx, label_id20, label_id200
def face_normal(vertex, face):
    """Compute unit normals and areas of triangle faces.

    Args:
        vertex: ``(V, 3)`` vertex positions.
        face: ``(F, K)`` vertex indices; only columns 0-2 are read.

    Returns:
        ``(nf, area)``:
        - ``nf``: ``(F, 3)`` normals, the cross product of the two edges
          leaving the first vertex divided by its magnitude.
        - ``area``: ``(F, 1)`` half of that magnitude.
        Both use ``magnitude + 1e-8`` so degenerate faces do not divide by
        zero.
    """
    origin = vertex[face[:, 0]]
    edge_a = vertex[face[:, 1]] - origin
    edge_b = vertex[face[:, 2]] - origin
    cross = np.cross(edge_a, edge_b)
    magnitude = np.sqrt((cross ** 2).sum(axis=1, keepdims=True)) + 1.0e-8
    return cross / magnitude, 0.5 * magnitude
def vertex_normal(vertex, face):
    """Compute per-vertex unit normals as area-weighted sums of face normals.

    Each face's unit normal, scaled by its area (both from ``face_normal``),
    is added to every vertex the face references. The per-vertex sums are
    then normalized with a ``1e-8`` epsilon, so a vertex referenced by no
    face ends up with a zero normal.

    Args:
        vertex: ``(V, 3)`` vertex positions.
        face: ``(F, K)`` vertex indices per face.

    Returns:
        ``(V, 3)`` array of per-vertex unit normals.
    """
    nf, area = face_normal(vertex, face)
    weighted = nf * area
    nv = np.zeros_like(vertex)
    # Unbuffered scatter-add. It is equivalent to
    # `for i in range(F): nv[face[i]] += weighted[i]` and accumulates in the
    # same face order, but runs as a single vectorized call instead of a
    # Python loop over every face.
    np.add.at(nv, face, weighted[:, None, :])
    length = np.sqrt(np.sum(nv ** 2, axis=1, keepdims=True)) + 1.0e-8
    return nv / length
def _split_for_scene(scene_id, train_scenes, val_scenes):
    """Return 'train' or 'val' by list membership, otherwise 'test'."""
    if scene_id in train_scenes:
        return 'train'
    if scene_id in val_scenes:
        return 'val'
    return 'test'


def _load_scene_labels(segments_file, aggregations_file, num_points, labels_pd):
    """Build per-point 20-class, 200-class and instance labels for one scene.

    Points not covered by any aggregation group keep ``IGNORE_INDEX``. When
    groups overlap, the later group in ``segGroups`` wins.

    Returns:
        ``(semantic_gt20, semantic_gt200, instance_ids)``, each an int array
        of length *num_points*.
    """
    with open(segments_file) as f:
        seg_indices = np.array(json.load(f)['segIndices'])
    with open(aggregations_file) as f:
        seg_groups = json.load(f)['segGroups']

    semantic_gt20 = np.full(num_points, IGNORE_INDEX, dtype=int)
    semantic_gt200 = np.full(num_points, IGNORE_INDEX, dtype=int)
    instance_ids = np.full(num_points, IGNORE_INDEX, dtype=int)
    for group in seg_groups:
        point_idx, label_id20, label_id200 = \
            point_indices_from_group(seg_indices, group, labels_pd)
        semantic_gt20[point_idx] = label_id20
        semantic_gt200[point_idx] = label_id200
        instance_ids[point_idx] = group['id']
    return semantic_gt20, semantic_gt200, instance_ids


def handle_process(scene_path, output_path, labels_pd, train_scenes, val_scenes, parse_normals=True):
    """Convert one ScanNet scene folder into a ``.pth`` file.

    The scene id is the folder's basename. From the mesh
    ``<id>_vh_clean_2.ply``, vertex columns 0-2 are stored as ``coord`` and
    columns 3-5 as ``color``. ``normal`` is added when *parse_normals* is
    true. For train/val scenes, the segment and aggregation files are also
    read to add per-point ``semantic_gt20``, ``semantic_gt200`` and
    ``instance_gt`` labels (``IGNORE_INDEX`` where unlabeled).

    Coordinates are saved as stored in the mesh; no axis-alignment transform
    is applied.

    The dict is written with ``torch.save`` to
    ``<output_path>/<split>/<id>.pth``. The split is 'train' or 'val' by
    membership in *train_scenes* / *val_scenes*, otherwise 'test'.

    Raises:
        ValueError: if the mesh has no elements or its coordinates contain
            NaN/inf values.
    """
    scene_id = os.path.basename(scene_path)
    mesh_path = os.path.join(scene_path, f'{scene_id}{CLOUD_FILE_PFIX}.ply')
    segments_file = os.path.join(scene_path, f'{scene_id}{CLOUD_FILE_PFIX}{SEGMENTS_FILE_PFIX}')
    aggregations_file = os.path.join(scene_path, f'{scene_id}{AGGREGATIONS_FILE_PFIX}')

    split_name = _split_for_scene(scene_id, train_scenes, val_scenes)
    output_file = os.path.join(output_path, split_name, f'{scene_id}.pth')
    print(f'Processing: {scene_id} in {split_name}')

    mesh = read_plymesh(mesh_path)
    if mesh is None:
        raise ValueError(f'Empty mesh in Scene: {scene_id}')
    vertices, faces = mesh
    coords = vertices[:, :3]
    colors = vertices[:, 3:6]
    # Validate the float geometry (label arrays are ints and cannot hold NaN);
    # non-finite coordinates would also poison the normals computed below.
    if not np.all(np.isfinite(coords)):
        raise ValueError(f'Find NaN in Scene: {scene_id}')

    save_dict = dict(coord=coords, color=colors, scene_id=scene_id)
    if parse_normals:
        save_dict["normal"] = vertex_normal(coords, faces)

    # Labels are only generated for train/val scenes.
    if split_name != "test":
        semantic_gt20, semantic_gt200, instance_ids = _load_scene_labels(
            segments_file, aggregations_file, vertices.shape[0], labels_pd)
        save_dict["semantic_gt20"] = semantic_gt20
        save_dict["semantic_gt200"] = semantic_gt200
        save_dict["instance_gt"] = instance_ids

    torch.save(save_dict, output_file)
if __name__ == '__main__':
    def _parse_bool(value):
        """Parse a boolean command-line value.

        ``type=bool`` would turn every non-empty string, including "False",
        into True, so the flag could never be switched off.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if lowered in ('', '0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_root', required=True, help='Path to the ScanNet dataset containing scene folders')
    parser.add_argument('--output_root', required=True, help='Output path where train/val folders will be located')
    parser.add_argument('--parse_normals', default=True, type=_parse_bool, help='Whether parse point normals')
    parser.add_argument('--num_workers', default=mp.cpu_count(), type=int,
                        help='Number of worker processes (default: CPU count)')
    config = parser.parse_args()

    # Metadata directory: the historical cwd-relative location when it exists,
    # otherwise a `meta_data` folder next to this script (assumed to be where
    # the `meta_data` constants package lives; verify for your layout).
    meta_dir = 'scannet-preprocess/meta_data'
    if not os.path.isdir(meta_dir):
        meta_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'meta_data')

    # Load label map
    labels_pd = pd.read_csv(os.path.join(meta_dir, 'scannetv2-labels.combined.tsv'),
                            sep='\t', header=0)
    # Load train/val splits
    with open(os.path.join(meta_dir, 'scannetv2_train.txt')) as train_file:
        train_scenes = train_file.read().splitlines()
    with open(os.path.join(meta_dir, 'scannetv2_val.txt')) as val_file:
        val_scenes = val_file.read().splitlines()

    # One output folder per split; handle_process writes into them.
    for split in ('train', 'val', 'test'):
        os.makedirs(os.path.join(config.output_root, split), exist_ok=True)

    scene_paths = sorted(glob.glob(os.path.join(config.dataset_root, 'scans*', 'scene*')))

    print('Processing scenes...')
    # The context manager shuts the pool down (and joins workers) on exit.
    with ProcessPoolExecutor(max_workers=config.num_workers) as pool:
        # Consume the iterator so exceptions raised in workers surface here.
        _ = list(pool.map(handle_process, scene_paths, repeat(config.output_root), repeat(labels_pd),
                          repeat(train_scenes), repeat(val_scenes), repeat(config.parse_normals)))