# NOTE(review): removed extraction residue ("Spaces:" / "Paused" lines) that was not valid Python.
import cv2 | |
import h5py | |
import numpy as np | |
import os | |
import trimesh | |
import torch | |
import json | |
from collections import defaultdict | |
import tqdm | |
import pickle | |
from random import shuffle | |
# Local imports | |
from StructDiffusion.utils.rearrangement import show_pcs, get_pts, array_to_tensor | |
from StructDiffusion.utils.pointnet import pc_normalize | |
import StructDiffusion.utils.brain2.camera as cam | |
import StructDiffusion.utils.brain2.image as img | |
import StructDiffusion.utils.transformations as tra | |
def load_pairwise_collision_data(h5_filename):
    """Load one pairwise-collision sample file into plain Python/NumPy objects.

    Args:
        h5_filename: path to an h5 file with string-repr object infos,
            per-candidate pose arrays, and per-candidate intersection labels.

    Returns:
        dict with keys "obj1_info", "obj2_info" (dicts), "obj1_poses",
        "obj2_poses" (pose arrays), and "intersection_labels" (label array).
    """
    # Context manager ensures the file handle is closed even on error
    # (the original implementation leaked the handle).
    with h5py.File(h5_filename, "r") as fh:
        data_dict = {}
        # SECURITY NOTE: eval() executes arbitrary code from the file; only
        # load h5 files from trusted sources (json/ast.literal_eval would be
        # safer if the stored repr allows it).
        data_dict["obj1_info"] = eval(fh["obj1_info"][()])
        data_dict["obj2_info"] = eval(fh["obj2_info"][()])
        # [:] copies each dataset into memory so the arrays remain valid
        # after the file is closed.
        data_dict["obj1_poses"] = fh["obj1_poses"][:]
        data_dict["obj2_poses"] = fh["obj2_poses"][:]
        data_dict["intersection_labels"] = fh["intersection_labels"][:]
    return data_dict
def replace_root_directory(original_filename: str, new_root: str) -> str:
    """Re-root *original_filename* at *new_root*.

    Everything up to and including the "data_new_objects" path component is
    replaced by *new_root*; the components after the anchor are kept.

    Raises:
        ValueError: if "data_new_objects" is not a component of the path.
    """
    parts = original_filename.split("/")
    # Keep only the components that follow the "data_new_objects" anchor.
    tail = parts[parts.index("data_new_objects") + 1:]
    return "/".join(new_root.split("/") + tail)
class PairwiseCollisionDataset(torch.utils.data.Dataset):
    """Dataset of object-pair point clouds labeled with intersection (collision).

    Each sample is built by: (1) looking up a collision h5 file and an index
    into its pose/label arrays, (2) sampling a recorded point-cloud
    observation for each object's urdf, (3) transforming each object's cloud
    to its candidate pose, and (4) merging both clouds into one scene tensor
    with a per-point one-hot object indicator.
    """

    def __init__(self, urdf_pc_idx_file, collision_data_dir, random_rotation=True,
                 num_pts=1024, normalize_pc=True, num_scene_pts=2048, data_augmentation=False,
                 debug=False, new_data_root=None):
        """
        Args:
            urdf_pc_idx_file: pickle file mapping urdf -> list of pc samples,
                each {"step_t": step_t, "obj": obj, "filename": filename}.
            collision_data_dir: directory of pairwise-collision h5 files, or
                None (dataset will then be empty).
            random_rotation: apply a random rotation about z to each scene.
            num_pts: number of points sampled per object point cloud.
            normalize_pc: normalize scene xyz with pc_normalize.
            num_scene_pts: number of points sampled for the combined scene.
            data_augmentation: add depth/xyz noise to observations.
            debug: print labels and visualize samples.
            new_data_root: if given, re-root stored pc filenames here.
        """
        # Load dictionary mapping from urdf to list of pc data; each sample is
        # {"step_t": step_t, "obj": obj, "filename": filename}.
        with open(urdf_pc_idx_file, "rb") as fh:
            self.urdf_to_pc_data = pickle.load(fh)
        # Filter out known-broken data files and optionally re-root paths.
        for urdf in self.urdf_to_pc_data:
            valid_pc_data = []
            for pd in self.urdf_to_pc_data[urdf]:
                filename = pd["filename"]
                if "data00026058" in filename or "data00011415" in filename or "data00026061" in filename or "data00700565" in filename or "data00505290" in filename:
                    continue
                if new_data_root:
                    pd["filename"] = replace_root_directory(pd["filename"], new_data_root)
                valid_pc_data.append(pd)
            # NOTE(review): if every sample for a urdf is broken, the original
            # (unfiltered) list is intentionally kept — preserves prior behavior.
            if valid_pc_data:
                self.urdf_to_pc_data[urdf] = valid_pc_data
        # Build the data index: each entry is a tuple of
        # (collision filename, idx into that file's pose/label arrays).
        if collision_data_dir is not None:
            self.data_idxs = self.build_data_idxs(collision_data_dir)
        else:
            print("WARNING: collision_data_dir is None")
            # Fix: initialize to an empty index so len(self) == 0 instead of
            # raising AttributeError later.
            self.data_idxs = []
        self.num_pts = num_pts
        self.debug = debug
        self.normalize_pc = normalize_pc
        self.num_scene_pts = num_scene_pts
        self.random_rotation = random_rotation
        # Data-augmentation noise parameters.
        self.data_augmentation = data_augmentation
        # additive noise: smoothed Gaussian field added to xyz
        self.gp_rescale_factor_range = [12, 20]
        self.gaussian_scale_range = [0., 0.003]
        # multiplicative noise: gamma-distributed depth scaling
        self.gamma_shape = 1000.
        self.gamma_scale = 0.001

    def build_data_idxs(self, collision_data_dir):
        """Scan *collision_data_dir* and build a class-balanced sample index.

        Returns:
            list of (h5_filename, idx) tuples, balanced between intersecting
            and non-intersecting pairs by random subsampling of the majority.
        """
        print("Load collision data...")
        positive_data = []
        negative_data = []
        for filename in tqdm.tqdm(os.listdir(collision_data_dir)):
            if "h5" not in filename:
                continue
            h5_filename = os.path.join(collision_data_dir, filename)
            data_dict = load_pairwise_collision_data(h5_filename)
            obj1_urdf = data_dict["obj1_info"]["urdf"]
            obj2_urdf = data_dict["obj2_info"]["urdf"]
            # Skip files whose objects have no recorded point-cloud data.
            if obj1_urdf not in self.urdf_to_pc_data:
                print("no pc data for urdf:", obj1_urdf)
                continue
            if obj2_urdf not in self.urdf_to_pc_data:
                print("no pc data for urdf:", obj2_urdf)
                continue
            for idx, l in enumerate(data_dict["intersection_labels"]):
                if l:
                    # intersection
                    positive_data.append((h5_filename, idx))
                else:
                    negative_data.append((h5_filename, idx))
        print("Num pairwise intersections:", len(positive_data))
        print("Num pairwise no intersections:", len(negative_data))
        # Balance the two classes by randomly subsampling the larger one.
        if len(negative_data) != len(positive_data):
            min_len = min(len(negative_data), len(positive_data))
            positive_data = [positive_data[i] for i in np.random.permutation(len(positive_data))[:min_len]]
            negative_data = [negative_data[i] for i in np.random.permutation(len(negative_data))[:min_len]]
            print("after balancing")
            print("Num pairwise intersections:", len(positive_data))
            print("Num pairwise no intersections:", len(negative_data))
        return positive_data + negative_data

    def create_urdf_pc_idxs(self, urdf_pc_idx_file, data_roots, index_roots):
        """Build and cache the urdf -> point-cloud-sample index.

        Reads train arrangement index files under each (data_root, index_root)
        pair, maps every object in every arrangement step to its urdf, and
        pickles the resulting index to *urdf_pc_idx_file*.

        Returns:
            dict mapping urdf path -> list of
            {"step_t": step_t, "obj": obj, "filename": filename}.
        """
        print("Load pc data")
        arrangement_steps = []
        for split in ["train"]:
            for data_root, index_root in zip(data_roots, index_roots):
                arrangement_indices_file = os.path.join(data_root, index_root, "{}_arrangement_indices_file_all.txt".format(split))
                if os.path.exists(arrangement_indices_file):
                    with open(arrangement_indices_file, "r") as fh:
                        # NOTE(review): eval() on index-file content — only use
                        # with trusted data.
                        arrangement_steps.extend([(os.path.join(data_root, f[0]), f[1]) for f in eval(fh.readline().strip())])
                else:
                    print("{} does not exist".format(arrangement_indices_file))
        urdf_to_pc_data = defaultdict(list)
        for filename, step_t in tqdm.tqdm(arrangement_steps):
            # Fix: close each h5 file (the original leaked one handle per step).
            with h5py.File(filename, 'r') as h5:
                ids = self._get_ids(h5)
                all_objs = sorted([o for o in ids.keys() if "object_" in o])
                goal_specification = json.loads(str(np.array(h5["goal_specification"])))
                # Objects are listed in the same order as the sorted ids;
                # rearrange + anchor + distract covers all scene objects.
                obj_infos = goal_specification["rearrange"]["objects"] + goal_specification["anchor"]["objects"] + goal_specification["distract"]["objects"]
                for obj, obj_info in zip(all_objs, obj_infos):
                    urdf_to_pc_data[obj_info["urdf"]].append({"step_t": step_t, "obj": obj, "filename": filename})
        with open(urdf_pc_idx_file, "wb") as fh:
            pickle.dump(urdf_to_pc_data, fh)
        return urdf_to_pc_data

    def add_noise_to_depth(self, depth_img):
        """Apply multiplicative gamma-distributed noise to a depth image."""
        multiplicative_noise = np.random.gamma(self.gamma_shape, self.gamma_scale)
        depth_img = multiplicative_noise * depth_img
        return depth_img

    def add_noise_to_xyz(self, xyz_img, depth_img):
        """Add a smoothed Gaussian noise field to an (H, W, 3) xyz image.

        Noise is sampled at low resolution and upscaled with bicubic
        interpolation so it is spatially correlated; only pixels with valid
        depth (> 0) are perturbed.
        """
        xyz_img = xyz_img.copy()
        H, W, C = xyz_img.shape
        gp_rescale_factor = np.random.randint(self.gp_rescale_factor_range[0],
                                              self.gp_rescale_factor_range[1])
        gp_scale = np.random.uniform(self.gaussian_scale_range[0],
                                     self.gaussian_scale_range[1])
        small_H, small_W = (np.array([H, W]) / gp_rescale_factor).astype(int)
        additive_noise = np.random.normal(loc=0.0, scale=gp_scale, size=(small_H, small_W, C))
        additive_noise = cv2.resize(additive_noise, (W, H), interpolation=cv2.INTER_CUBIC)
        xyz_img[depth_img > 0, :] += additive_noise[depth_img > 0, :]
        return xyz_img

    def _get_images(self, h5, idx, ee=True):
        """Decode one camera frame from an open h5 file.

        Args:
            h5: open h5py.File of an arrangement recording.
            idx: time-step index into the image datasets.
            ee: use the end-effector camera datasets instead of the fixed ones.

        Returns:
            tuple (rgb, depth, seg, valid, xyz) where xyz is in world frame.
        """
        if ee:
            RGB, DEPTH, SEG = "ee_rgb", "ee_depth", "ee_seg"
            DMIN, DMAX = "ee_depth_min", "ee_depth_max"
        else:
            RGB, DEPTH, SEG = "rgb", "depth", "seg"
            DMIN, DMAX = "depth_min", "depth_max"
        dmin = h5[DMIN][idx]
        dmax = h5[DMAX][idx]
        rgb1 = img.PNGToNumpy(h5[RGB][idx])[:, :, :3] / 255.  # remove alpha
        # Depth is stored quantized to 16 bits over [dmin, dmax]; undo that here.
        depth1 = h5[DEPTH][idx] / 20000. * (dmax - dmin) + dmin
        seg1 = img.PNGToNumpy(h5[SEG][idx])
        valid1 = np.logical_and(depth1 > 0.1, depth1 < 2.)
        camera = cam.get_camera_from_h5(h5)
        if self.data_augmentation:
            depth1 = self.add_noise_to_depth(depth1)
        xyz1 = cam.compute_xyz(depth1, camera)
        if self.data_augmentation:
            xyz1 = self.add_noise_to_xyz(xyz1, depth1)
        # Transform the point cloud into the world frame.
        CAM_POSE = "ee_camera_view" if ee else "camera_view"
        cam_pose = h5[CAM_POSE][idx]
        if ee:
            # ee_camera_view has 0s for x, y, z; take translation from ee_cam_pose.
            cam_pos = h5["ee_cam_pose"][:][:3, 3]
            cam_pose[:3, 3] = cam_pos
        h, w, d = xyz1.shape
        xyz1 = xyz1.reshape(h * w, -1)
        xyz1 = trimesh.transform_points(xyz1, cam_pose)
        xyz1 = xyz1.reshape(h, w, -1)
        scene1 = rgb1, depth1, seg1, valid1, xyz1
        return scene1

    def _get_ids(self, h5):
        """Return {object name: segmentation id} from the h5's "id_*" keys."""
        ids = {}
        for k in h5.keys():
            if k.startswith("id_"):
                ids[k[3:]] = h5[k][()]
        return ids

    def get_obj_pc(self, h5, step_t, obj):
        """Extract one object's point cloud at time *step_t*.

        Returns:
            (obj_xyz, obj_rgb, obj_pc_pose, obj_pose) where obj_pc_pose is an
            identity rotation placed at the cloud's centroid and obj_pose is
            the recorded object pose.

        Raises:
            Exception: if the object is not visible (empty mask).
        """
        scene = self._get_images(h5, step_t, ee=True)
        rgb, depth, seg, valid, xyz = scene
        # Select the pixels belonging to this object with valid depth.
        ids = self._get_ids(h5)
        obj_mask = np.logical_and(seg == ids[obj], valid)
        if np.sum(obj_mask) <= 0:
            # Fix: carry a message instead of a bare, silent Exception.
            raise Exception("no valid points for object {} at step {}".format(obj, step_t))
        ok, obj_xyz, obj_rgb, _ = get_pts(xyz, rgb, obj_mask, num_pts=self.num_pts, to_tensor=False)
        obj_pc_center = np.mean(obj_xyz, axis=0)
        obj_pose = h5[obj][step_t]
        obj_pc_pose = np.eye(4)
        obj_pc_pose[:3, 3] = obj_pc_center[:3]
        return obj_xyz, obj_rgb, obj_pc_pose, obj_pose

    def __len__(self):
        return len(self.data_idxs)

    def __getitem__(self, idx):
        """Build one training sample.

        Returns:
            dict with "scene_xyz" (num_scene_pts x 5 tensor: xyz + 2-dim
            one-hot object indicator) and "label" (1-element FloatTensor,
            1.0 = intersecting).
        """
        collision_filename, collision_idx = self.data_idxs[idx]
        collision_data_dict = load_pairwise_collision_data(collision_filename)
        obj1_urdf = collision_data_dict["obj1_info"]["urdf"]
        obj2_urdf = collision_data_dict["obj2_info"]["urdf"]
        # TODO: find a better way to sample pc data?
        obj1_pc_data = np.random.choice(self.urdf_to_pc_data[obj1_urdf])
        obj2_pc_data = np.random.choice(self.urdf_to_pc_data[obj2_urdf])
        # Fix: close the h5 files after reading (the original leaked two
        # handles per __getitem__ call).
        with h5py.File(obj1_pc_data["filename"], "r") as h5_1:
            obj1_xyz, obj1_rgb, obj1_pc_pose, obj1_pose = self.get_obj_pc(h5_1, obj1_pc_data["step_t"], obj1_pc_data["obj"])
        with h5py.File(obj2_pc_data["filename"], "r") as h5_2:
            obj2_xyz, obj2_rgb, obj2_pc_pose, obj2_pose = self.get_obj_pc(h5_2, obj2_pc_data["step_t"], obj2_pc_data["obj"])
        obj1_c_pose = collision_data_dict["obj1_poses"][collision_idx]
        obj2_c_pose = collision_data_dict["obj2_poses"][collision_idx]
        label = collision_data_dict["intersection_labels"][collision_idx]
        # Move each observed cloud from its recorded pose to the candidate pose.
        obj1_transform = obj1_c_pose @ np.linalg.inv(obj1_pose)
        obj2_transform = obj2_c_pose @ np.linalg.inv(obj2_pose)
        obj1_c_xyz = trimesh.transform_points(obj1_xyz, obj1_transform)
        obj2_c_xyz = trimesh.transform_points(obj2_xyz, obj2_transform)
        ###################################
        # Shuffle object order so the indicator channels carry no identity bias.
        obj_xyzs = [obj1_c_xyz, obj2_c_xyz]
        shuffle(obj_xyzs)
        num_indicator = 2
        new_obj_xyzs = []
        for oi, obj_xyz in enumerate(obj_xyzs):
            # Append a one-hot object indicator to every point.
            obj_xyz = np.concatenate([obj_xyz, np.tile(np.eye(num_indicator)[oi], (obj_xyz.shape[0], 1))], axis=1)
            new_obj_xyzs.append(obj_xyz)
        scene_xyz = np.concatenate(new_obj_xyzs, axis=0)
        # Subsample (with replacement) and normalize the combined cloud.
        # Fix: do not shadow the *idx* parameter.
        pt_idx = np.random.randint(0, scene_xyz.shape[0], self.num_scene_pts)
        scene_xyz = scene_xyz[pt_idx]
        if self.normalize_pc:
            scene_xyz[:, 0:3] = pc_normalize(scene_xyz[:, 0:3])
        if self.random_rotation:
            scene_xyz[:, 0:3] = trimesh.transform_points(scene_xyz[:, 0:3], tra.euler_matrix(0, 0, np.random.uniform(low=0, high=2 * np.pi)))
        ###################################
        scene_xyz = array_to_tensor(scene_xyz)
        # convert to torch data
        label = int(label)
        if self.debug:
            print("intersection:", label)
            # Fix: np.float was removed in NumPy 1.24; use the builtin float.
            show_pcs([scene_xyz[:, 0:3]], [np.tile(np.array([0, 1, 0], dtype=float), (scene_xyz.shape[0], 1))], add_coordinate_frame=True)
        datum = {
            "scene_xyz": scene_xyz,
            "label": torch.FloatTensor([label]),
        }
        return datum