|
import torch.utils.data |
|
|
|
from .SingleView_dataset import Object_Occ,Object_PartialPoints_MultiImg |
|
from .transforms import Scale_Shift_Rotate,Aug_with_Tran, Augment_Points |
|
from .taxonomy import synthetic_category_combined,synthetic_arkit_category_combined,arkit_category |
|
|
|
def build_object_occ_dataset(split,args): |
|
transform = Scale_Shift_Rotate(rot_shift_surface=True,use_scale=True,use_whole_scale=True) |
|
category=args['category'] |
|
|
|
category_list=synthetic_arkit_category_combined[category] |
|
replica=args['replica'] |
|
if split == "train": |
|
return Object_Occ(args['data_path'], split=split, categories=category_list, |
|
transform=transform, sampling=True, |
|
num_samples=args['num_samples'], return_surface=True, |
|
surface_sampling=True, surface_size=args['surface_size'],replica=replica) |
|
elif split == "val": |
|
return Object_Occ(args['data_path'], split=split,categories=category_list, |
|
transform=transform, sampling=False, |
|
num_samples=args['num_samples'], return_surface=True, |
|
surface_sampling=True,surface_size=args['surface_size'], replica=1) |
|
|
|
def build_par_multiimg_dataset(split,args): |
|
|
|
|
|
transform=Aug_with_Tran(par_jitter_sigma=args['jitter_partial_train']) |
|
val_transform=Aug_with_Tran(par_jitter_sigma=args['jitter_partial_val']) |
|
category=args['category'] |
|
category_list=synthetic_category_combined[category] |
|
if split == "train": |
|
return Object_PartialPoints_MultiImg(args['data_path'], split_filename="train_par_img.json",split=split, |
|
categories=category_list, |
|
transform=transform, sampling=True, |
|
num_samples=1024, return_surface=False,ret_sample=False, |
|
surface_sampling=True, par_pc_size=args['par_pc_size'],surface_size=args['surface_size'], |
|
load_proj_mat=args['load_proj_mat'],load_image=args['load_image'],load_triplane=True, |
|
par_prefix=args['par_prefix'],par_point_aug=args['par_point_aug'],replica=args['replica'], |
|
num_objects=args['num_objects']) |
|
elif split =="val": |
|
return Object_PartialPoints_MultiImg(args['data_path'], split_filename="val_par_img.json",split=split, |
|
categories=category_list, |
|
transform=val_transform, sampling=False, |
|
num_samples=1024, return_surface=False,ret_sample=True, |
|
surface_sampling=True, par_pc_size=args['par_pc_size'],surface_size=args['surface_size'], |
|
load_proj_mat=args['load_proj_mat'],load_image=args['load_image'],load_triplane=True, |
|
par_prefix=args['par_prefix'],par_point_aug=None,replica=1) |
|
|
|
def build_finetune_par_multiimg_dataset(split,args): |
|
|
|
|
|
keyword=args['keyword'] |
|
pretrain_transform=Aug_with_Tran(par_jitter_sigma=args['jitter_partial_pretrain']) |
|
finetune_transform=Aug_with_Tran(par_jitter_sigma=args['jitter_partial_finetune']) |
|
val_transform=Aug_with_Tran(par_jitter_sigma=args['jitter_partial_val']) |
|
|
|
pretrain_cat=synthetic_category_combined[args['category']] |
|
arkit_cat=arkit_category[args['category']] |
|
use_pretrain_data=args["use_pretrain_data"] |
|
|
|
if split == "train": |
|
if use_pretrain_data: |
|
pretrain_dataset=Object_PartialPoints_MultiImg(args['data_path'], split_filename="train_par_img.json",categories=pretrain_cat, |
|
split=split,transform=pretrain_transform, sampling=True,num_samples=1024, return_surface=False,ret_sample=False, |
|
surface_sampling=True, par_pc_size=args['par_pc_size'],surface_size=args['surface_size'], |
|
load_proj_mat=args['load_proj_mat'],load_image=args['load_image'],load_triplane=True,par_point_aug=args['par_point_aug'], |
|
par_prefix=args['par_prefix'],replica=1) |
|
finetune_dataset=Object_PartialPoints_MultiImg(args['data_path'], split_filename=keyword+"_train_par_img.json",categories=arkit_cat, |
|
split=split,transform=finetune_transform, sampling=True,num_samples=1024, return_surface=False,ret_sample=False, |
|
surface_sampling=True, par_pc_size=args['par_pc_size'],surface_size=args['surface_size'], |
|
load_proj_mat=args['load_proj_mat'],load_image=args['load_image'],load_triplane=True,par_point_aug=None,replica=args['replica']) |
|
if use_pretrain_data: |
|
return torch.utils.data.ConcatDataset([pretrain_dataset,finetune_dataset]) |
|
else: |
|
return finetune_dataset |
|
elif split =="val": |
|
return Object_PartialPoints_MultiImg(args['data_path'], split_filename=keyword+"_val_par_img.json",categories=arkit_cat,split=split, |
|
transform=val_transform, sampling=False, |
|
num_samples=1024, return_surface=False,ret_sample=True, |
|
surface_sampling=True, par_pc_size=args['par_pc_size'],surface_size=args['surface_size'], |
|
load_proj_mat=args['load_proj_mat'],load_image=args['load_image'],load_triplane=True,par_point_aug=None,replica=1) |
|
|
|
def build_dataset(split,args): |
|
if args['type']=="Occ": |
|
return build_object_occ_dataset(split,args) |
|
elif args['type']=="Occ_Par_MultiImg": |
|
return build_par_multiimg_dataset(split,args) |
|
elif args['type']=="Occ_Par_MultiImg_Finetune": |
|
return build_finetune_par_multiimg_dataset(split,args) |
|
else: |
|
raise NotImplementedError |