import os

import torch
import numpy as np

from StructDiffusion.utils.rearrangement import show_pcs_color_order, show_pcs_with_trimesh
from StructDiffusion.utils.pointnet import random_point_sample, index_points
import StructDiffusion.utils.tra3d as tra3d


def move_pc_and_create_scene_new(obj_xyzs, obj_params, struct_pose, current_pc_pose, target_object_inds, device,
                                 return_scene_pts=False, return_scene_pts_and_pc_idxs=False, num_scene_pts=None,
                                 normalize_pc=False, return_pair_pc=False, num_pair_pc_pts=None,
                                 normalize_pair_pc=False):
    """Move object point clouds to their goal poses and optionally build scene-level inputs.

    Args:
        obj_xyzs: object point clouds, shape (N, P, 3).
        obj_params: per-object goal parameters in the structure frame, shape (B, N, 6);
            the first three entries are a translation, the last three are "XYZ" euler angles.
        struct_pose: structure poses, shape (B x N, 4, 4).
        current_pc_pose: current point cloud poses, shape (B x N, 4, 4).
        target_object_inds: object validity indicator, shape (1, N). Only its first row is
            read; its sum is taken as the number of non-padding objects, and those objects
            are assumed to occupy the first slots.
        device: torch device on which newly created tensors are placed.
        return_scene_pts: if True, build a subsampled scene point cloud with a one-hot
            object indicator appended to every point.
        return_scene_pts_and_pc_idxs: if True, build a subsampled scene point cloud plus
            the index of the object each sampled point belongs to.
        num_scene_pts: number of points to sample per scene; required by the two flags above.
        normalize_pc: if True, normalize the xyz channels of the subsampled scene.
        return_pair_pc: if True, build point clouds for every pair of non-padding objects.
        num_pair_pc_pts: number of points to sample per object pair; required if
            return_pair_pc is True.
        normalize_pair_pc: if True, normalize the xyz channels of every pair point cloud.

    Returns:
        A tuple (new_obj_xyzs, goal_pc_pose, subsampled_scene_xyz, subsampled_pc_idxs, obj_pair_xyzs):
            new_obj_xyzs: moved point clouds, (B, N, P, 3); (B, N, P, 3 + N) with the one-hot
                indicator appended when return_scene_pts is True.
            goal_pc_pose: goal poses, (B, N, 4, 4).
            subsampled_scene_xyz: (B, num_scene_pts, 3 + N) for return_scene_pts,
                (B, num_scene_pts, 3) for return_scene_pts_and_pc_idxs (which takes precedence
                when both flags are set), otherwise None.
            subsampled_pc_idxs: (B, num_scene_pts) or None.
            obj_pair_xyzs: (B, num_comb, num_pair_pc_pts, 3 + 2) or None.

    Raises:
        ValueError: if scene points are requested but num_scene_pts is None.
    """
    B, N, _ = obj_params.shape
    _, P, _ = obj_xyzs.shape

    if (return_scene_pts or return_scene_pts_and_pc_idxs) and num_scene_pts is None:
        raise ValueError("num_scene_pts must be given when scene points are requested")

    # Goal pose of every object expressed in the structure frame: B x N, 4, 4
    flat_obj_params = obj_params.reshape(B * N, -1)
    goal_pc_pose_in_struct = torch.eye(4, device=device).repeat(B * N, 1, 1)
    goal_pc_pose_in_struct[:, :3, :3] = tra3d.euler_angles_to_matrix(flat_obj_params[:, 3:], "XYZ")
    goal_pc_pose_in_struct[:, :3, 3] = flat_obj_params[:, :3]

    goal_pc_pose = struct_pose @ goal_pc_pose_in_struct  # B x N, 4, 4
    goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_pose)  # B x N, 4, 4

    # Apply the transforms with homogeneous coordinates; this version does not rely on pytorch3d.
    homo_xyzs = torch.cat([obj_xyzs.repeat(B, 1, 1), torch.ones(B * N, P, 1, device=device)], dim=-1)  # B x N, P, 4
    moved_xyzs = torch.einsum('bij,bkj->bki', goal_pc_transform, homo_xyzs)[:, :, :3]  # B x N, P, 3
    moved_xyzs = moved_xyzs.reshape(B, N, P, -1)  # B, N, P, 3

    # moved_xyzs stays free of indicator channels so that every branch below can use it
    # regardless of which other outputs were requested.
    new_obj_xyzs = moved_xyzs
    subsampled_scene_xyz = None
    subsampled_pc_idxs = None
    obj_pair_xyzs = None

    if return_scene_pts or return_scene_pts_and_pc_idxs or return_pair_pc:
        # Convert once to a python int so that float / bool indicator tensors also work.
        num_valid_objs = int(torch.sum(target_object_inds[0]).item())
        # Padding objects are ignored by sampling only from the first num_valid_objs * P points.
        num_valid_pts = num_valid_objs * P

    ###############################################
    # Scene point cloud with one-hot object indicators
    if return_scene_pts:
        indicator_variables = torch.eye(N, device=device)[None, :, None, :].expand(B, N, P, N)  # B, N, P, N
        new_obj_xyzs = torch.cat([moved_xyzs, indicator_variables], dim=-1)  # B, N, P, 3 + N

        # combine pcs in each scene
        scene_xyzs = new_obj_xyzs.reshape(B, N * P, 3 + N)

        # ToDo: maybe convert this to a batch operation
        subsampled_scene_xyz = torch.empty(B, num_scene_pts, 3 + N, dtype=torch.float32, device=device)
        for si, scene_xyz in enumerate(scene_xyzs):
            # scene_xyz: N*P, 3+N
            subsample_idx = torch.randint(0, num_valid_pts, (num_scene_pts,)).to(device)
            subsampled_scene_xyz[si] = scene_xyz[subsample_idx]

        if normalize_pc:
            subsampled_scene_xyz[:, :, 0:3] = pc_normalize_batch(subsampled_scene_xyz[:, :, 0:3])

    ###############################################
    # Scene point cloud with per-point object indices
    if return_scene_pts_and_pc_idxs:
        # combine pcs in each scene
        scene_xyzs = moved_xyzs.reshape(B, N * P, 3)
        # Index of the object every point belongs to: B, N*P
        pc_idxs = torch.arange(N, device=device).repeat_interleave(P).expand(B, N * P)

        subsampled_scene_xyz = torch.empty(B, num_scene_pts, 3, dtype=torch.float32, device=device)
        subsampled_pc_idxs = torch.empty(B, num_scene_pts, dtype=torch.long, device=device)
        for si, (scene_xyz, pc_idx) in enumerate(zip(scene_xyzs, pc_idxs)):
            # scene_xyz: N*P, 3
            subsample_idx = torch.randint(0, num_valid_pts, (num_scene_pts,)).to(device)
            subsampled_scene_xyz[si] = scene_xyz[subsample_idx]
            subsampled_pc_idxs[si] = pc_idx[subsample_idx]

        if normalize_pc:
            subsampled_scene_xyz[:, :, 0:3] = pc_normalize_batch(subsampled_scene_xyz[:, :, 0:3])

    ###############################################
    # Create input for pairwise collision detector
    if return_pair_pc:
        assert num_pair_pc_pts is not None

        # ignore paddings: only pair up the non-padding objects
        obj_pair_idxs = torch.combinations(torch.arange(num_valid_objs), r=2)  # num_comb, 2

        obj_pair_xyzs = moved_xyzs[:, obj_pair_idxs]  # B, num_comb, 2 (obj 1 and obj 2), P, 3
        num_comb = obj_pair_xyzs.shape[1]
        pair_indicator_variables = torch.eye(2, device=device)[None, None, :, None, :].expand(
            B, num_comb, 2, P, 2)  # B, num_comb, 2, P, 2
        # 3 (pc channels) + 2 (indicator for obj 1 and obj 2)
        obj_pair_xyzs = torch.cat([obj_pair_xyzs, pair_indicator_variables], dim=-1)  # B, num_comb, 2, P, 5

        # random_point_sample() input dim: B, N, C
        obj_pair_xyzs = obj_pair_xyzs.reshape(B * num_comb, P * 2, 5)
        rand_idxs = random_point_sample(obj_pair_xyzs, num_pair_pc_pts)  # B * num_comb, num_pair_pc_pts
        obj_pair_xyzs = index_points(obj_pair_xyzs, rand_idxs)  # B * num_comb, num_pair_pc_pts, 5

        if normalize_pair_pc:
            # pc_normalize_batch() input dim: B, num_pts, 3
            obj_pair_xyzs[:, :, 0:3] = pc_normalize_batch(obj_pair_xyzs[:, :, 0:3])

        obj_pair_xyzs = obj_pair_xyzs.reshape(B, num_comb, num_pair_pc_pts, 5)

    goal_pc_pose = goal_pc_pose.reshape(B, N, 4, 4)

    return new_obj_xyzs, goal_pc_pose, subsampled_scene_xyz, subsampled_pc_idxs, obj_pair_xyzs


def compute_current_and_goal_pc_poses(obj_xyzs, struct_pose, pc_poses_in_struct):
    """Compute current and goal poses of object point clouds.

    The current pose of an object is an identity rotation placed at the mean of its
    points; the goal pose is the structure pose composed with the object's pose in
    the structure frame.

    Args:
        obj_xyzs: object point clouds, shape (B, N, P, 3).
        struct_pose: structure poses, shape (B, 1, 4, 4).
        pc_poses_in_struct: object poses in the structure frame, shape (B, N, 4, 4).

    Returns:
        A tuple (current_pc_poses, goal_pc_poses), both of shape (B, N, 4, 4).
    """
    device = obj_xyzs.device
    B, N, _, _ = pc_poses_in_struct.shape

    current_pc_poses = torch.eye(4).repeat(B, N, 1, 1).to(device)  # B, N, 4, 4
    current_pc_poses[:, :, :3, 3] = torch.mean(obj_xyzs, dim=2)  # translation = centroid of each object

    # matmul broadcasts struct_pose (B, 1, 4, 4) over the N objects
    goal_pc_poses = struct_pose @ pc_poses_in_struct  # B, N, 4, 4
    return current_pc_poses, goal_pc_poses


def sample_gaussians(mus, sigmas, sample_size):
    """Draw samples from a set of independent gaussians.

    Args:
        mus: means, shape [number of individual gaussians].
        sigmas: standard deviations, shape [number of individual gaussians].
        sample_size: number of samples to draw from every gaussian.

    Returns:
        Samples of shape [sample_size, number of individual gaussians].
    """
    normal = torch.distributions.Normal(mus, sigmas)
    return normal.sample((sample_size,))


def fit_gaussians(samples, sigma_eps=0.01):
    """Fit independent gaussians to samples.

    Args:
        samples: shape [sample_size, number of individual gaussians].
        sigma_eps: constant added to every fitted standard deviation.

    Returns:
        A tuple (mus, sigmas), each of shape [number of individual gaussians].
    """
    # Reductions already live on the device of `samples`; no transfer is needed.
    mus = torch.mean(samples, dim=0)
    sigmas = torch.std(samples, dim=0) + sigma_eps
    return mus, sigmas


def visualize_batch_pcs(obj_xyzs, B, verbose=False, limit_B=None, save_dir=None, trimesh=False):
    """Visualize (or save images of) a batch of scenes made of object point clouds.

    Args:
        obj_xyzs: tensor or array indexed as [example, object, point, channel]; the first
            three channels are used as xyz and, when `trimesh` is True, the remaining
            channels are passed as the second argument of show_pcs_with_trimesh.
        B: batch size; used as the number of examples to show when limit_B is None.
        verbose: if True, print the index and shape of every example.
        limit_B: maximum number of examples to show.
        save_dir: if given (and `trimesh` is False), images are saved to
            "<save_dir>/b<index>.jpg" instead of being displayed.
        trimesh: if True, display with show_pcs_with_trimesh.
    """
    if limit_B is None:
        limit_B = B

    vis_obj_xyzs = obj_xyzs[:limit_B]
    if torch.is_tensor(vis_obj_xyzs):
        # detach() is required for any tensor that tracks gradients, not only CUDA ones
        vis_obj_xyzs = vis_obj_xyzs.detach().cpu().numpy()

    for bi, vis_obj_xyz in enumerate(vis_obj_xyzs):
        if verbose:
            print("example {}".format(bi))
            print(vis_obj_xyz.shape)

        pcs = [xyz[:, :3] for xyz in vis_obj_xyz]
        if trimesh:
            show_pcs_with_trimesh(pcs, [xyz[:, 3:] for xyz in vis_obj_xyz])
        elif save_dir:
            os.makedirs(save_dir, exist_ok=True)
            save_path = os.path.join(save_dir, "b{}.jpg".format(bi))
            show_pcs_color_order(pcs, None, visualize=False, add_coordinate_frame=False, side_view=True,
                                 save_path=save_path)
        else:
            show_pcs_color_order(pcs, None, visualize=True, add_coordinate_frame=False, side_view=True)


def pc_normalize_batch(pc):
    """Center every point cloud at its centroid and scale it to unit radius.

    Args:
        pc: point clouds, shape (B, num_scene_pts, 3).

    Returns:
        Normalized point clouds with the same shape as the input.
    """
    centroid = torch.mean(pc, dim=1)  # B, 3
    pc = pc - centroid[:, None, :]
    m = torch.max(torch.sqrt(torch.sum(pc ** 2, dim=2)), dim=1)[0]  # B, largest distance to the centroid
    # A degenerate cloud (all points identical) has radius 0; clamp to avoid NaNs from 0 / 0.
    m = torch.clamp(m, min=1e-12)
    pc = pc / m[:, None, None]
    return pc