Weiyu Liu
add demo
8c02843
raw
history blame
11.7 kB
import sys
import os
import h5py
import torch
import pytorch3d.transforms as tra3d
from StructDiffusion.utils.rearrangement import show_pcs_color_order
from StructDiffusion.utils.pointnet import random_point_sample, index_points
def switch_stdout(stdout_filename=None):
if stdout_filename:
print("setting stdout to {}".format(stdout_filename))
if os.path.exists(stdout_filename):
sys.stdout = open(stdout_filename, 'a')
else:
sys.stdout = open(stdout_filename, 'w')
else:
sys.stdout = sys.__stdout__
def visualize_batch_pcs(obj_xyzs, B, N, P, verbose=True, limit_B=None):
if limit_B is None:
limit_B = B
vis_obj_xyzs = obj_xyzs.reshape(B, N, P, -1)
vis_obj_xyzs = vis_obj_xyzs[:limit_B]
if type(vis_obj_xyzs).__module__ == torch.__name__:
if vis_obj_xyzs.is_cuda:
vis_obj_xyzs = vis_obj_xyzs.detach().cpu()
vis_obj_xyzs = vis_obj_xyzs.numpy()
for bi, vis_obj_xyz in enumerate(vis_obj_xyzs):
if verbose:
print("example {}".format(bi))
print(vis_obj_xyz.shape)
show_pcs_color_order([xyz[:, :3] for xyz in vis_obj_xyz], None, visualize=True, add_coordinate_frame=True, add_table=False)
def convert_bool(d):
for k in d:
if type(d[k]) == list:
d[k] = [bool(i) for i in d[k]]
else:
d[k] = bool(d[k])
return d
def save_dict_to_h5(dict_data, filename):
fh = h5py.File(filename, 'w')
for k in dict_data:
key_data = dict_data[k]
if key_data is None:
raise RuntimeError('data was not properly populated')
# if type(key_data) is dict:
# key_data = json.dumps(key_data, sort_keys=True)
try:
fh.create_dataset(k, data=key_data)
except TypeError as e:
print("Failure on key", k)
print(key_data)
print(e)
raise e
fh.close()
def move_pc_and_create_scene_new(obj_xyzs, obj_params, struct_pose, current_pc_pose, target_object_inds, device,
return_scene_pts=False, return_scene_pts_and_pc_idxs=False, num_scene_pts=None, normalize_pc=False,
return_pair_pc=False, num_pair_pc_pts=None, normalize_pair_pc=False):
# obj_xyzs: N, P, 3
# obj_params: B, N, 6
# struct_pose: B x N, 4, 4
# current_pc_pose: B x N, 4, 4
# target_object_inds: 1, N
B, N, _ = obj_params.shape
_, P, _ = obj_xyzs.shape
# B, N, 6
flat_obj_params = obj_params.reshape(B * N, -1)
goal_pc_pose_in_struct = torch.eye(4).repeat(B * N, 1, 1).to(device)
goal_pc_pose_in_struct[:, :3, :3] = tra3d.euler_angles_to_matrix(flat_obj_params[:, 3:], "XYZ")
goal_pc_pose_in_struct[:, :3, 3] = flat_obj_params[:, :3] # B x N, 4, 4
goal_pc_pose = struct_pose @ goal_pc_pose_in_struct
goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_pose) # cur_batch_size x N, 4, 4
# important: pytorch3d uses row-major ordering, need to transpose each transformation matrix
transpose = tra3d.Transform3d(matrix=goal_pc_transform.transpose(1, 2))
# obj_xyzs: N, P, 3
new_obj_xyzs = obj_xyzs.repeat(B, 1, 1)
new_obj_xyzs = transpose.transform_points(new_obj_xyzs)
# put it back to B, N, P, 3
new_obj_xyzs = new_obj_xyzs.reshape(B, N, P, -1)
# visualize_batch_pcs(new_obj_xyzs, S, N, P)
# initialize the additional outputs
subsampled_scene_xyz = None
subsampled_pc_idxs = None
obj_pair_xyzs = None
# ===================================
# Pass to discriminator
if return_scene_pts:
num_indicator = N
# add one hot
indicator_variables = torch.eye(num_indicator).repeat(B, 1, 1, P).reshape(B, num_indicator, P, num_indicator).to(device) # B, N, P, N
# print(indicator_variables.shape)
# print(new_obj_xyzs.shape)
new_obj_xyzs = torch.cat([new_obj_xyzs, indicator_variables], dim=-1) # B, N, P, 3 + N
# combine pcs in each scene
scene_xyzs = new_obj_xyzs.reshape(B, N * P, 3 + N)
# ToDo: maybe convert this to a batch operation
subsampled_scene_xyz = torch.FloatTensor(B, num_scene_pts, 3 + N).to(device)
for si, scene_xyz in enumerate(scene_xyzs):
# scene_xyz: N*P, 3+N
# target_object_inds: 1, N
subsample_idx = torch.randint(0, torch.sum(target_object_inds[0]) * P, (num_scene_pts,)).to(device)
subsampled_scene_xyz[si] = scene_xyz[subsample_idx]
# # debug:
# print("-"*50)
# if si < 10:
# trimesh.PointCloud(scene_xyz[:, :3].cpu().numpy(), colors=[255, 0, 0, 255]).show()
# trimesh.PointCloud(subsampled_scene_xyz[si, :, :3].cpu().numpy(), colors=[0, 255, 0, 255]).show()
# subsampled_scene_xyz: B, num_scene_pts, 3+N
# new_obj_xyzs: B, N, P, 3
# goal_pc_pose: B, N, 4, 4
# important:
if normalize_pc:
subsampled_scene_xyz[:, :, 0:3] = pc_normalize_batch(subsampled_scene_xyz[:, :, 0:3])
# # debug:
# for si in range(10):
# trimesh.PointCloud(subsampled_scene_xyz[si, :, :3].cpu().numpy(), colors=[0, 0, 255, 255]).show()
if return_scene_pts_and_pc_idxs:
num_indicator = N
pc_idxs = torch.arange(0, num_indicator)[:, None].repeat(B, 1, P).reshape(B, num_indicator, P).to(device) # B, N, P
# new_obj_xyzs: B, N, P, 3 + 1
# combine pcs in each scene
scene_xyzs = new_obj_xyzs.reshape(B, N * P, 3)
pc_idxs = pc_idxs.reshape(B, N*P)
subsampled_scene_xyz = torch.FloatTensor(B, num_scene_pts, 3).to(device)
subsampled_pc_idxs = torch.LongTensor(B, num_scene_pts).to(device)
for si, (scene_xyz, pc_idx) in enumerate(zip(scene_xyzs, pc_idxs)):
# scene_xyz: N*P, 3+1
# target_object_inds: 1, N
subsample_idx = torch.randint(0, torch.sum(target_object_inds[0]) * P, (num_scene_pts,)).to(device)
subsampled_scene_xyz[si] = scene_xyz[subsample_idx]
subsampled_pc_idxs[si] = pc_idx[subsample_idx]
# subsampled_scene_xyz: B, num_scene_pts, 3
# subsampled_pc_idxs: B, num_scene_pts
# new_obj_xyzs: B, N, P, 3
# goal_pc_pose: B, N, 4, 4
# important:
if normalize_pc:
subsampled_scene_xyz[:, :, 0:3] = pc_normalize_batch(subsampled_scene_xyz[:, :, 0:3])
# TODO: visualize each individual object
# debug
# print(subsampled_scene_xyz.shape)
# print(subsampled_pc_idxs.shape)
# print("visualize subsampled scene")
# for si in range(5):
# trimesh.PointCloud(subsampled_scene_xyz[si, :, :3].cpu().numpy(), colors=[0, 0, 255, 255]).show()
###############################################
# Create input for pairwise collision detector
if return_pair_pc:
assert num_pair_pc_pts is not None
# new_obj_xyzs: B, N, P, 3 + N
# target_object_inds: 1, N
# ignore paddings
num_objs = torch.sum(target_object_inds[0])
obj_pair_idxs = torch.combinations(torch.arange(num_objs), r=2) # num_comb, 2
# use [:, :, :, :3] to get obj_xyzs without object-wise indicator
obj_pair_xyzs = new_obj_xyzs[:, :, :, :3][:, obj_pair_idxs] # B, num_comb, 2 (obj 1 and obj 2), P, 3
num_comb = obj_pair_xyzs.shape[1]
pair_indicator_variables = torch.eye(2).repeat(B, num_comb, 1, 1, P).reshape(B, num_comb, 2, P, 2).to(device) # B, num_comb, 2, P, 2
obj_pair_xyzs = torch.cat([obj_pair_xyzs, pair_indicator_variables], dim=-1) # B, num_comb, 2, P, 3 (pc channels) + 2 (indicator for obj 1 and obj 2)
obj_pair_xyzs = obj_pair_xyzs.reshape(B, num_comb, P * 2, 5)
# random sample: idx = np.random.randint(0, scene_xyz.shape[0], self.num_scene_pts)
obj_pair_xyzs = obj_pair_xyzs.reshape(B * num_comb, P * 2, 5)
# random_point_sample() input dim: B, N, C
rand_idxs = random_point_sample(obj_pair_xyzs, num_pair_pc_pts) # B * num_comb, num_pair_pc_pts
obj_pair_xyzs = index_points(obj_pair_xyzs, rand_idxs) # B * num_comb, num_pair_pc_pts, 5
if normalize_pair_pc:
# pc_normalize_batch() input dim: pc: B, num_scene_pts, 3
# obj_pair_xyzs = obj_pair_xyzs.reshape(B * num_comb, num_pair_pc_pts, 5)
obj_pair_xyzs[:, :, 0:3] = pc_normalize_batch(obj_pair_xyzs[:, :, 0:3])
obj_pair_xyzs = obj_pair_xyzs.reshape(B, num_comb, num_pair_pc_pts, 5)
# # debug
# for bi, this_obj_pair_xyzs in enumerate(obj_pair_xyzs):
# print("batch id", bi)
# for pi, obj_pair_xyz in enumerate(this_obj_pair_xyzs):
# print("pair", pi)
# # obj_pair_xyzs: 2 * P, 5
# print(obj_pair_xyz[:, :3].shape)
# trimesh.PointCloud(obj_pair_xyz[:, :3].cpu()).show()
# obj_pair_xyzs: B, num_comb, num_pair_pc_pts, 3 + 2
goal_pc_pose = goal_pc_pose.reshape(B, N, 4, 4)
return new_obj_xyzs, goal_pc_pose, subsampled_scene_xyz, subsampled_pc_idxs, obj_pair_xyzs
def move_pc(obj_xyzs, obj_params, struct_pose, current_pc_pose, device):
# obj_xyzs: N, P, 3
# obj_params: B, N, 6
# struct_pose: B x N, 4, 4
# current_pc_pose: B x N, 4, 4
# target_object_inds: 1, N
B, N, _ = obj_params.shape
_, P, _ = obj_xyzs.shape
# B, N, 6
flat_obj_params = obj_params.reshape(B * N, -1)
goal_pc_pose_in_struct = torch.eye(4).repeat(B * N, 1, 1).to(device)
goal_pc_pose_in_struct[:, :3, :3] = tra3d.euler_angles_to_matrix(flat_obj_params[:, 3:], "XYZ")
goal_pc_pose_in_struct[:, :3, 3] = flat_obj_params[:, :3] # B x N, 4, 4
goal_pc_pose = struct_pose @ goal_pc_pose_in_struct
goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_pose) # cur_batch_size x N, 4, 4
# important: pytorch3d uses row-major ordering, need to transpose each transformation matrix
transpose = tra3d.Transform3d(matrix=goal_pc_transform.transpose(1, 2))
# obj_xyzs: N, P, 3
new_obj_xyzs = obj_xyzs.repeat(B, 1, 1)
new_obj_xyzs = transpose.transform_points(new_obj_xyzs)
# put it back to B, N, P, 3
new_obj_xyzs = new_obj_xyzs.reshape(B, N, P, -1)
# visualize_batch_pcs(new_obj_xyzs, S, N, P)
# subsampled_scene_xyz: B, num_scene_pts, 3+N
# new_obj_xyzs: B, N, P, 3
# goal_pc_pose: B, N, 4, 4
goal_pc_pose = goal_pc_pose.reshape(B, N, 4, 4)
return new_obj_xyzs, goal_pc_pose
def sample_gaussians(mus, sigmas, sample_size):
# mus: [number of individual gaussians]
# sigmas: [number of individual gaussians]
normal = torch.distributions.Normal(mus, sigmas)
samples = normal.sample((sample_size,))
# samples: [sample_size, number of individual gaussians]
return samples
def fit_gaussians(samples, sigma_eps=0.01):
device = samples.device
# samples: [sample_size, number of individual gaussians]
num_gs = samples.shape[1]
mus = torch.mean(samples, dim=0).to(device)
sigmas = torch.std(samples, dim=0).to(device) + sigma_eps * torch.ones(num_gs).to(device)
# mus: [number of individual gaussians]
# sigmas: [number of individual gaussians]
return mus, sigmas
def pc_normalize_batch(pc):
# pc: B, num_scene_pts, 3
centroid = torch.mean(pc, dim=1) # B, 3
pc = pc - centroid[:, None, :]
m = torch.max(torch.sqrt(torch.sum(pc ** 2, dim=2)), dim=1)[0]
pc = pc / m[:, None, None]
return pc