"""Evaluate triplane-diffusion shape generation: IoU, F-score and Chamfer
distance (before ICP alignment) on a validation split, with optional export
of meshes, partial-point inputs, images and ground-truth surfaces."""
import argparse
import sys

sys.path.append("..")
sys.path.append(".")

import os
import time
from pathlib import Path

import cv2
import mcubes
import numpy as np
import open3d as o3d
import scipy
import torch
import trimesh

from configs.config_utils import CONFIG
from datasets.SingleView_dataset import Object_PartialPoints_MultiImg
from datasets.transforms import Scale_Shift_Rotate
from models import get_model
from pyTorchChamferDistance.chamfer_distance import ChamferDistance
from util import misc
from util.misc import MetricLogger
from util.projection_utils import draw_proj_image

dist_chamfer = ChamferDistance()


def pc_metrics(p1, p2, space_ext=2, fscore_param=0.01, scale=.5):
    """Compute Chamfer-L1, Chamfer-L2 and F-score between two point clouds.

    Args:
        p1: predicted points, (B, N, 3) CUDA tensor.
        p2: reference points, (B, M, 3) CUDA tensor.
        space_ext: extent of the bounding space; with ``fscore_param`` it
            defines the F-score inlier threshold.
        fscore_param: fraction of ``space_ext`` used as the F-score threshold.
        scale: uniform scale applied to both clouds before measuring.

    Returns:
        Tuple ``(chamfer_L1, chamfer_L2, fscore)`` of per-batch tensors.
    """
    p1, p2, space_ext = p1 * scale, p2 * scale, space_ext * scale
    f_thresh = space_ext * fscore_param
    # d1/d2 are squared nearest-neighbor distances p1->p2 and p2->p1.
    d1, d2, _, _ = dist_chamfer(p1, p2)
    d1sqrt, d2sqrt = (d1 ** .5), (d2 ** .5)
    chamfer_L1 = d1sqrt.mean(axis=-1) + d2sqrt.mean(axis=-1)
    chamfer_L2 = d1.mean(axis=-1) + d2.mean(axis=-1)
    precision = (d1sqrt < f_thresh).sum(axis=-1).float() / p1.shape[1]
    recall = (d2sqrt < f_thresh).sum(axis=-1).float() / p2.shape[1]
    fscore = 2 * torch.div(recall * precision, recall + precision)
    # When precision == recall == 0 the division is 0/0 -> NaN (the original
    # code only checked for inf, which this expression never produces);
    # treat "no inliers at all" as an F-score of 0.
    fscore[torch.isnan(fscore)] = 0
    fscore[fscore == float("inf")] = 0
    return chamfer_L1, chamfer_L2, fscore


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        'this script can be used to compute iou fscore chamfer distance before icp align',
        add_help=False)
    parser.add_argument('--configs', type=str, required=True)
    parser.add_argument('--output_folder', type=str,
                        default="../output_result/Triplane_diff_parcond_0926")
    parser.add_argument('--dm-pth', type=str)
    parser.add_argument('--ae-pth', type=str)
    parser.add_argument('--data-pth', type=str, default="../")
    parser.add_argument('--save_mesh', action="store_true", default=False)
    parser.add_argument('--save_image', action="store_true", default=False)
    parser.add_argument('--save_par_points', action="store_true", default=False)
    parser.add_argument('--save_proj_img', action="store_true", default=False)
    parser.add_argument('--save_surface', action="store_true", default=False)
    parser.add_argument('--reso', default=128, type=int)
    parser.add_argument('--category', nargs="+", type=str)
    parser.add_argument('--eval_cd', action="store_true", default=False)
    parser.add_argument('--use_augmentation', action="store_true", default=False)
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    args = parser.parse_args()
    misc.init_distributed_mode(args)

    config = CONFIG(args.configs)
    dataset_config = config.config['dataset']
    dataset_config['data_path'] = args.data_pth

    # ARKit data uses a keyword-prefixed validation split file.
    if "arkit" in args.category[0]:
        split_filename = dataset_config['keyword'] + '_val_par_img.json'
    else:
        split_filename = 'val_par_img.json'

    transform = None
    if args.use_augmentation:
        transform = Scale_Shift_Rotate(jitter_partial=False, jitter=False,
                                       use_scale=False, angle=(-10, 10),
                                       shift=(-0.1, 0.1))

    dataset_val = Object_PartialPoints_MultiImg(
        dataset_config['data_path'], split_filename=split_filename,
        categories=args.category, split="val",
        transform=transform, sampling=False,
        num_samples=1024, return_surface=True, ret_sample=True,
        surface_sampling=True, par_pc_size=dataset_config['par_pc_size'],
        surface_size=100000, load_proj_mat=True, load_image=True,
        load_org_img=True, load_triplane=None, par_point_aug=None, replica=1)

    batch_size = 1
    num_tasks = misc.get_world_size()
    global_rank = misc.get_rank()
    val_sampler = torch.utils.data.DistributedSampler(
        dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
    dataloader_val = torch.utils.data.DataLoader(
        dataset_val,
        sampler=val_sampler,
        batch_size=batch_size,
        num_workers=10,
        shuffle=False,
    )

    output_folder = args.output_folder
    device = torch.device('cuda')
    ae_config = config.config['model']['ae']
    dm_config = config.config['model']['dm']
    ae_model = get_model(ae_config).to(device)
    # Category embedding is only needed when evaluating all classes jointly.
    if args.category[0] == "all":
        dm_config["use_cat_embedding"] = True
    else:
        dm_config["use_cat_embedding"] = False
    dm_model = get_model(dm_config).to(device)
    ae_model.eval()
    dm_model.eval()
    ae_model.load_state_dict(torch.load(args.ae_pth)['model'])
    dm_model.load_state_dict(torch.load(args.dm_pth)['model'])

    # Dense (reso+1)^3 query grid over [-1.1, 1.1]^3 for marching cubes.
    density = args.reso
    gap = 2.2 / density
    x = np.linspace(-1.1, 1.1, int(density + 1))
    y = np.linspace(-1.1, 1.1, int(density + 1))
    z = np.linspace(-1.1, 1.1, int(density + 1))
    xv, yv, zv = np.meshgrid(x, y, z, indexing='ij')
    grid = torch.from_numpy(np.stack([xv, yv, zv]).astype(np.float32)) \
        .view(3, -1).transpose(0, 1)[None].to(device, non_blocking=True)

    metric_logger = MetricLogger(delimiter="  ")
    header = 'Test:'

    with torch.no_grad():
        for data_batch in metric_logger.log_every(dataloader_val, 10, header):
            partial_name = data_batch['partial_name']
            class_name = data_batch['class_name']
            model_ids = data_batch['model_id']
            surface = data_batch['surface']
            proj_matrices = data_batch['proj_mat']
            sample_points = data_batch["points"].cuda().float()
            labels = data_batch["labels"].cuda().float()

            sample_input = dm_model.prepare_sample_data(data_batch)
            sampled_array = dm_model.sample(sample_input, num_steps=36).float()
            # Upsample the generated triplane features 2x before decoding.
            sampled_array = torch.nn.functional.interpolate(
                sampled_array, scale_factor=2, mode="bilinear")

            for j in range(sampled_array.shape[0]):
                if args.save_mesh or args.save_par_points or args.save_image:
                    object_folder = os.path.join(output_folder, class_name[j],
                                                 model_ids[j])
                    Path(object_folder).mkdir(parents=True, exist_ok=True)

                # ---- IoU on the sampled occupancy query points ----
                sample_point = sample_points[j:j + 1]
                sample_output = ae_model.decode(sampled_array[j:j + 1], sample_point)
                sample_pred = torch.zeros_like(sample_output)
                sample_pred[sample_output >= 0.0] = 1
                label = labels[j:j + 1]
                intersection = (sample_pred * label).sum(dim=1)
                union = (sample_pred + label).gt(0).sum(dim=1)
                # Epsilon belongs in the denominator to guard empty unions;
                # the original added it to the ratio (precedence bug).
                iou = intersection * 1.0 / (union + 1e-5)
                iou = iou.mean()
                metric_logger.update(iou=iou.item())

                if args.use_augmentation:
                    tran_mat = data_batch["tran_mat"][j].numpy()
                    mat_save_path = '{}/tran_mat.npy'.format(object_folder)
                    np.save(mat_save_path, tran_mat)

                if args.eval_cd:
                    # Decode the dense grid in chunks to bound GPU memory.
                    grid_list = torch.split(grid, 128 ** 3, dim=1)
                    output_list = []
                    for sub_grid in grid_list:
                        output_list.append(
                            ae_model.decode(sampled_array[j:j + 1], sub_grid))
                    output = torch.cat(output_list, dim=1)
                    # The decode batch always has size 1 (decoded from
                    # sampled_array[j:j+1]), so index 0 here — the original's
                    # output[j] only worked because batch_size == 1.
                    logits = output[0].detach()

                    volume = logits.view(density + 1, density + 1,
                                         density + 1).cpu().numpy()
                    verts, faces = mcubes.marching_cubes(volume, 0)
                    # Map voxel indices back to world coordinates.
                    verts *= gap
                    verts -= 1.1
                    m = trimesh.Trimesh(verts, faces)

                    # ---- F-score / Chamfer distance vs. GT surface ----
                    result_surface, _ = trimesh.sample.sample_surface(m, 100000)
                    gt_surface = surface[j]
                    assert gt_surface.shape[0] == result_surface.shape[0]
                    result_surface_gpu = torch.from_numpy(
                        result_surface).float().cuda().unsqueeze(0)
                    gt_surface_gpu = gt_surface.float().cuda().unsqueeze(0)
                    _, chamfer_L2, fscore = pc_metrics(result_surface_gpu,
                                                       gt_surface_gpu)
                    metric_logger.update(chamferl2=chamfer_L2 * 1000.0)
                    metric_logger.update(fscore=fscore)

                    if args.save_mesh:
                        m.export('{}/{}_mesh.ply'.format(object_folder,
                                                         partial_name[j]))
                    if args.save_par_points:
                        par_point_input = data_batch['par_points'][j].numpy()
                        par_point_o3d = o3d.geometry.PointCloud()
                        par_point_o3d.points = o3d.utility.Vector3dVector(
                            par_point_input[:, 0:3])
                        o3d.io.write_point_cloud(
                            '{}/{}.ply'.format(object_folder, partial_name[j]),
                            par_point_o3d)
                    if args.save_image:
                        image_list = data_batch["org_image"]
                        for idx, image in enumerate(image_list):
                            image = image[0].numpy().astype(np.uint8)
                            if args.save_proj_img:
                                proj_mat = proj_matrices[j, idx].numpy()
                                proj_image = draw_proj_image(image, proj_mat,
                                                             result_surface)
                                proj_save_path = '{}/proj_{}.jpg'.format(
                                    object_folder, idx)
                                cv2.imwrite(proj_save_path, proj_image)
                            save_path = '{}/{}.jpg'.format(object_folder, idx)
                            cv2.imwrite(save_path, image)
                    if args.save_surface:
                        # Distinct local name: the original reused `surface`,
                        # clobbering the batch tensor for subsequent items.
                        gt_surface_np = gt_surface.numpy().astype(np.float32)
                        surface_o3d = o3d.geometry.PointCloud()
                        surface_o3d.points = o3d.utility.Vector3dVector(
                            gt_surface_np[:, 0:3])
                        o3d.io.write_point_cloud(
                            '{}/surface.ply'.format(object_folder), surface_o3d)

    metric_logger.synchronize_between_processes()
    print('* iou {ious.global_avg:.3f}'.format(ious=metric_logger.iou))
    if args.eval_cd:
        print('* chamferl2 {chamferl2s.global_avg:.3f}'.format(
            chamferl2s=metric_logger.chamferl2))
        print('* fscore {fscores.global_avg:.3f}'.format(
            fscores=metric_logger.fscore))