StableRecon / croco /datasets /habitat_sim /multiview_habitat_sim_generator.py
Stable-X's picture
Upload folder using huggingface_hub
e4bf056 verified
raw
history blame
19.3 kB
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
import os
import numpy as np
import quaternion
import habitat_sim
import json
from sklearn.neighbors import NearestNeighbors
import cv2
# Change of basis between the OpenCV camera convention and the Habitat one.
# The matrix is built by stacking, as rows, Habitat's RIGHT, -UP and FRONT axes.
R_OPENCV2HABITAT = np.vstack((habitat_sim.geo.RIGHT, -habitat_sim.geo.UP, habitat_sim.geo.FRONT))
# Inverse change of basis: the transpose (valid as long as the three stacked
# axes are orthonormal, which is assumed here).
R_HABITAT2OPENCV = R_OPENCV2HABITAT.T
# Multiplicative factor converting an angle in degrees to radians.
DEG2RAD = np.pi / 180
def compute_camera_intrinsics(height, width, hfov):
    """Return pinhole intrinsics for square pixels and a centred principal point.

    Args:
        height: image height in pixels.
        width: image width in pixels.
        hfov: horizontal field of view, in degrees.

    Returns:
        Tuple ``(f, cu, cv)``: focal length in pixels, then the principal point
        coordinates (``width / 2``, ``height / 2``).
    """
    half_fov_rad = hfov / 2 * np.pi / 180
    principal_u = width / 2
    principal_v = height / 2
    # The focal length maps half the horizontal FOV onto half the image width.
    focal = principal_u / np.tan(half_fov_rad)
    return focal, principal_u, principal_v
def compute_camera_pose_opencv_convention(camera_position, camera_orientation):
    """Express a Habitat camera pose as an OpenCV-convention camera-to-world transform.

    Args:
        camera_position: camera position in world coordinates.
        camera_orientation: Habitat camera orientation, as a quaternion
            accepted by ``quaternion.as_rotation_matrix``.

    Returns:
        Tuple ``(R_cam2world, t_cam2world)``: 3x3 rotation taking OpenCV camera
        coordinates to world coordinates, and the position as a numpy array.
    """
    habitat_rotation = quaternion.as_rotation_matrix(camera_orientation)
    # Re-express the camera axes from the OpenCV convention to Habitat's before
    # applying the Habitat camera orientation.
    return habitat_rotation @ R_OPENCV2HABITAT, np.asarray(camera_position)
def compute_pointmap(depthmap, hfov):
    """Back-project an HxW depth map into an HxWx3 point map in the camera frame.

    The depth value is used directly as the z coordinate, and x/y are obtained
    with the pinhole model from ``compute_camera_intrinsics`` (OpenCV camera
    convention).

    Args:
        depthmap: HxW array of depths.
        hfov: horizontal field of view, in degrees.

    Returns:
        HxWx3 array whose last axis holds the ``(x, y, z)`` camera coordinates.
    """
    rows, cols = depthmap.shape
    focal, center_u, center_v = compute_camera_intrinsics(rows, cols, hfov)
    # Pixel coordinate grids: v is the row index, u the column index (both HxW).
    v_grid, u_grid = np.indices((rows, cols))
    x_cam = (u_grid - center_u) / focal * depthmap
    y_cam = (v_grid - center_v) / focal * depthmap
    return np.stack((x_cam, y_cam, depthmap), axis=-1)
def compute_pointcloud(depthmap, hfov, camera_position, camera_rotation):
    """Return the world-space 3D points seen by the valid pixels of a depth map.

    Pixels whose depth is exactly zero are treated as invalid and dropped.

    Args:
        depthmap: HxW depth map.
        hfov: horizontal field of view, in degrees.
        camera_position: camera position in world coordinates.
        camera_rotation: Habitat camera orientation quaternion.

    Returns:
        (N, 3) array of world points, in row-major pixel order.
    """
    rotation, translation = compute_camera_pose_opencv_convention(camera_position, camera_rotation)
    points_cam = compute_pointmap(depthmap=depthmap, hfov=hfov)
    has_depth = points_cam[..., 2] != 0.0
    # A 2-D boolean mask on an HxWx3 array yields the selected (N, 3) rows.
    points_cam = points_cam[has_depth]
    return points_cam @ rotation.T + translation.reshape(1, 3)
def _count_points_within(queries, references, distance_threshold):
    """Count query points whose nearest reference point is closer than ``distance_threshold``.

    Returns 0 when either set is empty: scikit-learn's ``NearestNeighbors``
    cannot be fitted on, or queried with, zero samples.
    """
    if len(queries) == 0 or len(references) == 0:
        return 0
    nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(references)
    distances, _ = nbrs.kneighbors(queries)
    return np.count_nonzero(distances.flatten() < distance_threshold)


def compute_pointcloud_overlaps_scikit(pointcloud1, pointcloud2, distance_threshold, compute_symmetric=False, adaptive_threshold=False):
    """
    Compute 'overlapping' metrics based on a distance threshold between two point clouds.

    A point of one cloud counts as overlapping when its nearest neighbor in the
    other cloud lies strictly closer than the distance threshold. Empty clouds
    produce zero intersections instead of raising.

    Args:
        pointcloud1: (N, D) array of points (3D points in this file).
        pointcloud2: (M, D) array of points.
        distance_threshold: overlap distance threshold.
        compute_symmetric: if True, also count the points of ``pointcloud2``
            that are close to ``pointcloud1``.
        adaptive_threshold: if True, ignore ``distance_threshold`` and instead
            use the average of the mean nearest-neighbor spacing within each
            cloud. NOTE(review): a non-empty cloud with a single point will
            presumably make scikit-learn raise here (2 neighbors requested).

    Returns:
        dict with ``"intersection1"`` (number of points of ``pointcloud1``
        close to ``pointcloud2``) and ``"size1"``; plus ``"intersection2"`` and
        ``"size2"`` when ``compute_symmetric`` is True.
    """
    both_non_empty = len(pointcloud1) > 0 and len(pointcloud2) > 0
    # With an empty cloud no overlap is possible, so the adaptive threshold is
    # irrelevant (and could not be computed).
    if adaptive_threshold and both_non_empty:
        # Column 1 is the distance to the closest *other* point (column 0 is the point itself).
        distances1 = NearestNeighbors(n_neighbors=2).fit(pointcloud1).kneighbors(pointcloud1)[0][:, 1]
        distances2 = NearestNeighbors(n_neighbors=2).fit(pointcloud2).kneighbors(pointcloud2)[0][:, 1]
        distance_threshold = (np.mean(distances1) + np.mean(distances2)) / 2
    data = {"intersection1": _count_points_within(pointcloud1, pointcloud2, distance_threshold),
            "size1": len(pointcloud1)}
    if compute_symmetric:
        data["intersection2"] = _count_points_within(pointcloud2, pointcloud1, distance_threshold)
        data["size2"] = len(pointcloud2)
    return data
def _append_camera_parameters(observation, hfov, camera_location, camera_rotation):
    """
    Add camera parameters to the observation dictionary produced by Habitat-Sim.

    In-place modification: adds the keys ``"camera_intrinsics"`` (3x3 K
    matrix), ``"t_cam2world"`` and ``"R_cam2world"`` (OpenCV-convention pose).
    The image size is taken from ``observation['depth']``.
    """
    rotation, translation = compute_camera_pose_opencv_convention(camera_location, camera_rotation)
    rows, cols = observation['depth'].shape
    focal, center_u, center_v = compute_camera_intrinsics(rows, cols, hfov)
    intrinsics = np.array([[focal, 0, center_u],
                           [0, focal, center_v],
                           [0, 0, 1.0]])
    observation.update({"camera_intrinsics": intrinsics,
                        "t_cam2world": translation,
                        "R_cam2world": rotation})
def look_at(eye, center, up, return_cam2world=True):
    """
    Return camera pose looking at a given center point.
    Analogous of gluLookAt function, using OpenCV camera convention.

    The camera z axis points from ``eye`` towards ``center``. The camera y axis
    is ``-up`` made orthogonal to z and normalized. The x axis completes the
    frame as ``cross(y, z)``. All vector operations act on the last axis, so
    inputs of shape (..., 3) are supported.

    Args:
        eye: camera position(s), shape (..., 3). Arrays or sequences; integer
            dtypes are promoted to float.
        center: point(s) to look at, shape (..., 3).
        up: world "up" direction(s), shape (..., 3).
        return_cam2world: if True, return the camera-to-world rotation (the
            columns are the camera axes) with ``eye`` as translation. Otherwise
            return the world-to-camera rotation and translation.

    Returns:
        Tuple ``(R, t)``.

    Note:
        Degenerate (NaN output) when ``eye == center`` or when the viewing
        direction is parallel to ``up``.
    """
    # Out-of-place division: an in-place ``/=`` would fail for integer inputs.
    z = np.asarray(center) - np.asarray(eye)
    z = z / np.linalg.norm(z, axis=-1, keepdims=True)
    # Gram-Schmidt: remove the component of -up along z, then normalize.
    y = -np.asarray(up)
    y = y - np.sum(y * z, axis=-1, keepdims=True) * z
    y = y / np.linalg.norm(y, axis=-1, keepdims=True)
    x = np.cross(y, z, axis=-1)
    if return_cam2world:
        R = np.stack((x, y, z), axis=-1)
        t = eye
    else:
        # World to camera transformation
        # Transposed matrix
        R = np.stack((x, y, z), axis=-2)
        t = - np.einsum('...ij, ...j', R, eye)
    return R, t
def look_at_for_habitat(eye, center, up, return_cam2world=True):
    """Return a Habitat orientation quaternion and position looking from ``eye`` at ``center``.

    ``return_cam2world`` is accepted for signature compatibility but ignored:
    the camera-to-world pose is always computed.

    Returns:
        Tuple ``(orientation, position)``, where ``position`` is the translation
        returned by ``look_at`` (i.e. ``eye``).
    """
    cam2world_opencv, position = look_at(eye, center, up)
    # Convert the OpenCV-convention camera frame back into Habitat's camera frame.
    cam2world_habitat = cam2world_opencv @ R_OPENCV2HABITAT.T
    return quaternion.from_rotation_matrix(cam2world_habitat), position
def generate_orientation_noise(pan_range, tilt_range, roll_range):
    """Sample a random orientation perturbation quaternion.

    Pan, tilt and roll angles are drawn (in this order) uniformly from their
    ``(low, high)`` ranges, in degrees. They are applied as rotations about
    Habitat's UP, RIGHT and FRONT axes, composed as ``pan * tilt * roll``.
    """
    pan, tilt, roll = (np.random.uniform(*angle_range) * DEG2RAD
                       for angle_range in (pan_range, tilt_range, roll_range))
    q_pan = quaternion.from_rotation_vector(pan * habitat_sim.geo.UP)
    q_tilt = quaternion.from_rotation_vector(tilt * habitat_sim.geo.RIGHT)
    q_roll = quaternion.from_rotation_vector(roll * habitat_sim.geo.FRONT)
    return q_pan * q_tilt * q_roll
class NoNaviguableSpaceError(RuntimeError):
    """Raised when no navmesh can be loaded or computed for a scene."""
class MultiviewHabitatSimGenerator:
    """Generate sets of overlapping RGB-D views rendered with Habitat-Sim.

    Each item consists of a reference view sampled at a random navigable
    location, plus ``views_count - 1`` additional views. Each additional view
    is obtained by a short random walk on the navmesh and is oriented towards
    the mean 3D point seen by the reference view. An additional view is kept
    only if it has enough valid depth and enough covisibility with the
    reference view.
    """

    def __init__(self,
                 scene,
                 navmesh,
                 scene_dataset_config_file,
                 resolution=(240, 320),
                 views_count=2,
                 hfov=60,
                 gpu_id=0,
                 size=10000,
                 minimum_covisibility=0.5,
                 transform=None):
        """
        Args:
            scene: scene identifier passed to ``SimulatorConfiguration.scene_id``.
            navmesh: path of a pre-computed navmesh; None or "" to compute one.
            scene_dataset_config_file: optional scene dataset config file
                (None or "" to skip).
            resolution: sensor resolution forwarded to the Habitat sensor specs
                (presumably (height, width)).
            views_count: number of views per item (must be >= 1).
            hfov: horizontal field of view, in degrees.
            gpu_id: GPU device id used by the simulator.
            size: value reported by ``__len__``.
            minimum_covisibility: minimum covisibility of an additional view
                with the reference view.
            transform: optional callable applied to each generated item.

        Raises:
            NoNaviguableSpaceError: if no navmesh can be loaded or computed.
        """
        self.scene = scene
        self.navmesh = navmesh
        self.scene_dataset_config_file = scene_dataset_config_file
        self.resolution = resolution
        self.views_count = views_count
        assert self.views_count >= 1
        self.hfov = hfov
        self.gpu_id = gpu_id
        self.size = size
        self.transform = transform
        # Noise added to camera orientation (ranges in degrees)
        self.pan_range = (-3, 3)
        self.tilt_range = (-10, 10)
        self.roll_range = (-5, 5)
        # Height range, above the sampled navmesh point, to sample cameras
        self.height_range = (1.2, 1.8)
        # Random steps between the camera views
        self.random_steps_count = 5
        # Scale multiplying the standard-normal random-walk steps
        self.random_step_variance = 2.0
        # Minimum fraction of the image which should be valid (non-zero depth)
        self.minimum_valid_fraction = 0.7
        # Distance threshold under which two 3D points are considered overlapping
        self.distance_threshold = 0.05
        # Minimum covisibility of a view point cloud with respect to the reference view to be kept.
        self.minimum_covisibility = minimum_covisibility
        # Maximum number of sampling attempts per additional view.
        self.max_attempts_count = 100
        self.seed = None
        self._lazy_initialization()

    def _lazy_initialization(self):
        """Seed the RNG, then create the simulator, its sensors and the navmesh.

        Raises:
            NoNaviguableSpaceError: if no navmesh can be loaded or computed.
        """
        # Lazy random seeding and instantiation of the simulator to deal with multiprocessing properly
        if self.seed is None:
            # Re-seed numpy generator
            np.random.seed()
            self.seed = np.random.randint(2**32-1)
        sim_cfg = habitat_sim.SimulatorConfiguration()
        sim_cfg.scene_id = self.scene
        if self.scene_dataset_config_file is not None and self.scene_dataset_config_file != "":
            sim_cfg.scene_dataset_config_file = self.scene_dataset_config_file
        sim_cfg.random_seed = self.seed
        sim_cfg.load_semantic_mesh = False
        sim_cfg.gpu_device_id = self.gpu_id

        # Depth and color sensors share the same resolution, FOV and position.
        depth_sensor_spec = habitat_sim.CameraSensorSpec()
        depth_sensor_spec.uuid = "depth"
        depth_sensor_spec.sensor_type = habitat_sim.SensorType.DEPTH
        depth_sensor_spec.resolution = self.resolution
        depth_sensor_spec.hfov = self.hfov
        depth_sensor_spec.position = [0.0, 0.0, 0]

        rgb_sensor_spec = habitat_sim.CameraSensorSpec()
        rgb_sensor_spec.uuid = "color"
        rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR
        rgb_sensor_spec.resolution = self.resolution
        rgb_sensor_spec.hfov = self.hfov
        rgb_sensor_spec.position = [0.0, 0.0, 0]
        agent_cfg = habitat_sim.agent.AgentConfiguration(sensor_specifications=[rgb_sensor_spec, depth_sensor_spec])

        cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg])
        self.sim = habitat_sim.Simulator(cfg)
        if self.navmesh is not None and self.navmesh != "":
            # Use pre-computed navmesh when available (usually better than those generated automatically)
            self.sim.pathfinder.load_nav_mesh(self.navmesh)

        if not self.sim.pathfinder.is_loaded:
            # Try to compute a navmesh
            navmesh_settings = habitat_sim.NavMeshSettings()
            navmesh_settings.set_defaults()
            self.sim.recompute_navmesh(self.sim.pathfinder, navmesh_settings, True)

        # Ensure that the navmesh is not empty
        if not self.sim.pathfinder.is_loaded:
            raise NoNaviguableSpaceError(f"No naviguable location (scene: {self.scene} -- navmesh: {self.navmesh})")

        self.agent = self.sim.initialize_agent(agent_id=0)

    def close(self):
        """Release the underlying simulator."""
        self.sim.close()

    def __del__(self):
        # __init__ may have failed before the simulator was created; in that
        # case there is no ``self.sim`` to release.
        sim = getattr(self, "sim", None)
        if sim is not None:
            sim.close()

    def __len__(self):
        return self.size

    def sample_random_viewpoint(self):
        """ Sample a random viewpoint using the navmesh

        Returns:
            Tuple ``(position, orientation, nav_point)``. The position is a
            random navigable point raised by a random height. The orientation
            is a uniformly random rotation about UP, combined with orientation
            noise.
        """
        nav_point = self.sim.pathfinder.get_random_navigable_point()
        # Sample a random viewpoint height
        viewpoint_height = np.random.uniform(*self.height_range)
        viewpoint_position = nav_point + viewpoint_height * habitat_sim.geo.UP
        viewpoint_orientation = quaternion.from_rotation_vector(np.random.uniform(0, 2 * np.pi) * habitat_sim.geo.UP) * generate_orientation_noise(self.pan_range, self.tilt_range, self.roll_range)
        return viewpoint_position, viewpoint_orientation, nav_point

    def sample_other_random_viewpoint(self, observed_point, nav_point):
        """ Sample a random viewpoint close to an existing one, using the navmesh and a reference observed point.

        Starting from ``nav_point``, performs ``random_steps_count`` random
        steps, each snapped to the navmesh; failed snaps (NaN) are ignored. The
        final point is then raised by a random height, and the camera is
        oriented towards ``observed_point`` with orientation noise.

        Returns:
            Tuple ``(position, rotation, other_nav_point)``.
        """
        other_nav_point = nav_point

        # No step along the second axis (presumably Habitat's vertical axis).
        walk_directions = self.random_step_variance * np.asarray([1, 0, 1])
        for _ in range(self.random_steps_count):
            temp = self.sim.pathfinder.snap_point(other_nav_point + walk_directions * np.random.normal(size=3))
            # Snapping may return nan when it fails
            if not np.isnan(temp[0]):
                other_nav_point = temp

        other_viewpoint_height = np.random.uniform(*self.height_range)
        other_viewpoint_position = other_nav_point + other_viewpoint_height * habitat_sim.geo.UP

        # Set viewing direction towards the central point
        rotation, position = look_at_for_habitat(eye=other_viewpoint_position, center=observed_point, up=habitat_sim.geo.UP, return_cam2world=True)
        rotation = rotation * generate_orientation_noise(self.pan_range, self.tilt_range, self.roll_range)
        return position, rotation, other_nav_point

    def is_other_pointcloud_overlapping(self, ref_pointcloud, other_pointcloud):
        """ Check if a viewpoint is valid and overlaps significantly with a reference one.

        Covisibility is the smaller of the two directional overlap counts,
        normalized by the number of image pixels.

        Returns:
            Tuple ``(is_valid, valid_fraction, covisibility)``.
        """
        # Observation
        pixels_count = self.resolution[0] * self.resolution[1]
        valid_fraction = len(other_pointcloud) / pixels_count
        assert valid_fraction <= 1.0 and valid_fraction >= 0.0
        if len(ref_pointcloud) == 0 or len(other_pointcloud) == 0:
            # Nothing can overlap. This also avoids building a nearest-neighbor
            # index on an empty point set, which raises.
            covisibility = 0.0
        else:
            overlap = compute_pointcloud_overlaps_scikit(ref_pointcloud, other_pointcloud, self.distance_threshold, compute_symmetric=True)
            covisibility = min(overlap["intersection1"] / pixels_count, overlap["intersection2"] / pixels_count)
        is_valid = (valid_fraction >= self.minimum_valid_fraction) and (covisibility >= self.minimum_covisibility)
        return is_valid, valid_fraction, covisibility

    def is_other_viewpoint_overlapping(self, ref_pointcloud, observation, position, rotation):
        """ Check if a viewpoint is valid and overlaps significantly with a reference one.

        Same as ``is_other_pointcloud_overlapping``, starting from a rendered
        observation (its ``'depth'`` map) and the corresponding camera pose.
        """
        # Observation
        other_pointcloud = compute_pointcloud(observation['depth'], self.hfov, position, rotation)
        return self.is_other_pointcloud_overlapping(ref_pointcloud, other_pointcloud)

    def render_viewpoint(self, viewpoint_position, viewpoint_orientation):
        """Place the agent at the given pose and return its sensor observations.

        The returned observations are extended in place with the camera
        intrinsics and OpenCV-convention pose (see ``_append_camera_parameters``).
        """
        agent_state = habitat_sim.AgentState()
        agent_state.position = viewpoint_position
        agent_state.rotation = viewpoint_orientation
        self.agent.set_state(agent_state)
        viewpoint_observations = self.sim.get_sensor_observations(agent_ids=0)
        _append_camera_parameters(viewpoint_observations, self.hfov, viewpoint_position, viewpoint_orientation)
        return viewpoint_observations

    def __getitem__(self, useless_idx):
        """Generate a new random multiview sample; the index is ignored.

        Returns:
            dict with ``"observations"`` (per-view sensor observations),
            ``"positions"``, ``"orientations"`` (WXYZ quaternion arrays),
            ``"covisibility_ratios"`` (with respect to view 0),
            ``"valid_fractions"`` and ``"pairwise_visibility_ratios"``. Entry
            [i, j] of the pairwise matrix is the fraction of image pixels of
            view i whose 3D point is close to view j's point cloud; the
            diagonal holds the valid fractions. If ``transform`` is set, its
            result is returned instead.
        """
        ref_position, ref_orientation, nav_point = self.sample_random_viewpoint()
        ref_observations = self.render_viewpoint(ref_position, ref_orientation)
        # Extract point cloud
        ref_pointcloud = compute_pointcloud(depthmap=ref_observations['depth'], hfov=self.hfov,
                                            camera_position=ref_position, camera_rotation=ref_orientation)

        pixels_count = self.resolution[0] * self.resolution[1]
        ref_valid_fraction = len(ref_pointcloud) / pixels_count
        assert ref_valid_fraction <= 1.0 and ref_valid_fraction >= 0.0
        if ref_valid_fraction < self.minimum_valid_fraction:
            # This should produce a recursion error at some point when something is very wrong.
            return self[0]
        # Pick an reference observed point in the point cloud
        observed_point = np.mean(ref_pointcloud, axis=0)

        # Add the first image as reference
        viewpoints_observations = [ref_observations]
        viewpoints_covisibility = [ref_valid_fraction]
        viewpoints_positions = [ref_position]
        viewpoints_orientations = [quaternion.as_float_array(ref_orientation)]
        viewpoints_clouds = [ref_pointcloud]
        viewpoints_valid_fractions = [ref_valid_fraction]

        for _ in range(self.views_count - 1):
            # Generate an other viewpoint using some dummy random walk
            successful_sampling = False
            for _ in range(self.max_attempts_count):
                position, rotation, _ = self.sample_other_random_viewpoint(observed_point, nav_point)
                # Observation
                other_viewpoint_observations = self.render_viewpoint(position, rotation)
                other_pointcloud = compute_pointcloud(other_viewpoint_observations['depth'], self.hfov, position, rotation)

                is_valid, valid_fraction, covisibility = self.is_other_pointcloud_overlapping(ref_pointcloud, other_pointcloud)
                if is_valid:
                    successful_sampling = True
                    break
            if not successful_sampling:
                print("WARNING: Maximum number of attempts reached.")
                # Dirty hack, try using a novel original viewpoint
                return self[0]
            viewpoints_observations.append(other_viewpoint_observations)
            viewpoints_covisibility.append(covisibility)
            viewpoints_positions.append(position)
            viewpoints_orientations.append(quaternion.as_float_array(rotation))  # WXYZ convention for the quaternion encoding.
            viewpoints_clouds.append(other_pointcloud)
            viewpoints_valid_fractions.append(valid_fraction)

        # Estimate relations between all pairs of images
        pairwise_visibility_ratios = np.ones((len(viewpoints_observations), len(viewpoints_observations)))
        for i in range(len(viewpoints_observations)):
            pairwise_visibility_ratios[i, i] = viewpoints_valid_fractions[i]
            for j in range(i + 1, len(viewpoints_observations)):
                overlap = compute_pointcloud_overlaps_scikit(viewpoints_clouds[i], viewpoints_clouds[j], self.distance_threshold, compute_symmetric=True)
                pairwise_visibility_ratios[i, j] = overlap['intersection1'] / pixels_count
                pairwise_visibility_ratios[j, i] = overlap['intersection2'] / pixels_count

        # IoU is relative to the image 0
        data = {"observations": viewpoints_observations,
                "positions": np.asarray(viewpoints_positions),
                "orientations": np.asarray(viewpoints_orientations),
                "covisibility_ratios": np.asarray(viewpoints_covisibility),
                "valid_fractions": np.asarray(viewpoints_valid_fractions, dtype=float),
                "pairwise_visibility_ratios": np.asarray(pairwise_visibility_ratios, dtype=float),
                }

        if self.transform is not None:
            data = self.transform(data)
        return data

    def generate_random_spiral_trajectory(self, images_count=100, max_radius=0.5, half_turns=5, use_constant_orientation=False):
        """
        Return a list of images corresponding to a spiral trajectory from a random starting point.
        Useful to generate nice visualisations.
        Use an even number of half turns to get a nice "C1-continuous" loop effect

        The spiral lies in the x-y plane of the reference camera (OpenCV
        convention). Its radius grows then shrinks back to zero along the
        trajectory.

        Returns:
            Tuple ``(images, all_valid)``: the list of RGB images (first three
            channels of ``'color'``), and whether every view was valid and
            overlapping with the reference one.
        """
        ref_position, ref_orientation, _ = self.sample_random_viewpoint()
        ref_observations = self.render_viewpoint(ref_position, ref_orientation)
        ref_pointcloud = compute_pointcloud(depthmap=ref_observations['depth'], hfov=self.hfov,
                                            camera_position=ref_position, camera_rotation=ref_orientation)
        pixels_count = self.resolution[0] * self.resolution[1]
        if len(ref_pointcloud) / pixels_count < self.minimum_valid_fraction:
            # Dirty hack: ensure that the valid part of the image is significant
            return self.generate_random_spiral_trajectory(images_count, max_radius, half_turns, use_constant_orientation)

        # Pick an observed point in the point cloud
        observed_point = np.mean(ref_pointcloud, axis=0)
        ref_R, _ = compute_camera_pose_opencv_convention(ref_position, ref_orientation)

        images = []
        is_valid = []
        # Spiral trajectory, use_constant orientation
        for alpha in np.linspace(0, 1, images_count):
            r = max_radius * np.abs(np.sin(alpha * np.pi))  # Increase then decrease the radius
            theta = alpha * half_turns * np.pi
            x = r * np.cos(theta)
            y = r * np.sin(theta)
            z = 0.0
            position = ref_position + (ref_R @ np.asarray([x, y, z]).reshape(3, 1)).flatten()
            if use_constant_orientation:
                orientation = ref_orientation
            else:
                # trajectory looking at a mean point in front of the ref observation
                orientation, position = look_at_for_habitat(eye=position, center=observed_point, up=habitat_sim.geo.UP)
            observations = self.render_viewpoint(position, orientation)
            images.append(observations['color'][..., :3])
            _is_valid, _, _ = self.is_other_viewpoint_overlapping(ref_pointcloud, observations, position, orientation)
            is_valid.append(_is_valid)
        return images, np.all(is_valid)