bite_gradio / src /stacked_hourglass /datasets /utils_dataset_selection.py
Nadine Rueegg
initial commit with code and data
753fd9a
import torch
from torch.utils.data import DataLoader
import cv2
import glob
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..', 'src'))
def get_evaluation_dataset(cfg_data_dataset, cfg_data_val_opt, cfg_data_V12, cfg_optim_batch_size, args_workers, drop_last=False):
# cfg_data_dataset = cfg.data.DATASET
# cfg_data_val_opt = cfg.data.VAL_OPT
# cfg_data_V12 = cfg.data.V12
# cfg_optim_batch_size = cfg.optim.BATCH_SIZE
# args_workers = args.workers
assert cfg_data_dataset in ['stanext24_easy', 'stanext24', 'stanext24_withgc', 'stanext24_withgc_big']
assert cfg_data_val_opt in ['train', 'test', 'val']
if cfg_data_dataset == 'stanext24_easy':
from stacked_hourglass.datasets.stanext24_easy import StanExtEasy as StanExt
dataset_mode = 'complete'
elif cfg_data_dataset == 'stanext24':
from stacked_hourglass.datasets.stanext24 import StanExt
dataset_mode = 'complete'
elif cfg_data_dataset == 'stanext24_withgc':
from stacked_hourglass.datasets.stanext24_withgc import StanExtGC as StanExt
dataset_mode = 'complete_with_gc'
elif cfg_data_dataset == 'stanext24_withgc_big':
from stacked_hourglass.datasets.stanext24_withgc_v2 import StanExtGC as StanExt
dataset_mode = 'complete_with_gc'
# Initialise the validation set dataloader
if cfg_data_val_opt == 'test':
val_dataset = StanExt(image_path=None, is_train=False, dataset_mode=dataset_mode, V12=cfg_data_V12, val_opt='test')
test_name_list = val_dataset.test_name_list
elif cfg_data_val_opt == 'val':
val_dataset = StanExt(image_path=None, is_train=False, dataset_mode=dataset_mode, V12=cfg_data_V12, val_opt='val')
test_name_list = val_dataset.test_name_list
elif cfg_data_val_opt == 'train':
val_dataset = StanExt(image_path=None, is_train=True, do_augment='no', dataset_mode=dataset_mode, V12=cfg_data_V12)
test_name_list = val_dataset.train_name_list
else:
raise ValueError
val_loader = DataLoader(val_dataset, batch_size=cfg_optim_batch_size, shuffle=False,
num_workers=args_workers, pin_memory=True, drop_last=drop_last) # False) # , drop_last=True args.batch_size
len_val_dataset = len(val_dataset)
stanext_data_info = StanExt.DATA_INFO
stanext_acc_joints = StanExt.ACC_JOINTS
return val_dataset, val_loader, len_val_dataset, test_name_list, stanext_data_info, stanext_acc_joints
def get_sketchfab_evaluation_dataset(cfg_optim_batch_size, args_workers):
# cfg_optim_batch_size = cfg.optim.BATCH_SIZE
# args_workers = args.workers
from stacked_hourglass.datasets.sketchfab import SketchfabScans
val_dataset = SketchfabScans(image_path=None, is_train=False, dataset_mode='complete')
test_name_list = val_dataset.test_name_list
val_loader = DataLoader(val_dataset, batch_size=cfg_optim_batch_size, shuffle=False,
num_workers=args_workers, pin_memory=True, drop_last=False) # drop_last=True)
from stacked_hourglass.datasets.stanext24 import StanExt
len_val_dataset = len(val_dataset)
stanext_data_info = StanExt.DATA_INFO
stanext_acc_joints = StanExt.ACC_JOINTS
return val_dataset, val_loader, len_val_dataset, test_name_list, stanext_data_info, stanext_acc_joints
def get_crop_evaluation_dataset(cfg_optim_batch_size, args_workers, input_folder):
from stacked_hourglass.datasets.imgcropslist import ImgCrops
image_list_paths = glob.glob(os.path.join(input_folder, '*.jpg')) + glob.glob(os.path.join(input_folder, '*.png'))
image_list = []
test_name_list = []
for image_path in image_list_paths:
test_name_list.append(os.path.basename(image_path).split('.')[0])
img = cv2.imread(image_path)
image_list.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
val_dataset = ImgCrops(image_list=image_list, bbox_list=None)
val_loader = DataLoader(val_dataset, batch_size=cfg_optim_batch_size, shuffle=False,
num_workers=args_workers, pin_memory=True, drop_last=False) # drop_last=True)
from stacked_hourglass.datasets.stanext24 import StanExt
len_val_dataset = len(val_dataset)
stanext_data_info = StanExt.DATA_INFO
stanext_acc_joints = StanExt.ACC_JOINTS
return val_dataset, val_loader, len_val_dataset, test_name_list, stanext_data_info, stanext_acc_joints
def get_single_crop_dataset_from_image(input_image, bbox=None):
from stacked_hourglass.datasets.imgcropslist import ImgCrops
input_image_list = [input_image]
if bbox is not None:
input_bbox_list = [bbox]
else:
input_bbox_list = None
# prepare data loader
val_dataset = ImgCrops(image_list=input_image_list, bbox_list=input_bbox_list, dataset_mode='complete')
test_name_list = val_dataset.test_name_list
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False,
num_workers=0, pin_memory=True, drop_last=False)
from stacked_hourglass.datasets.stanext24 import StanExt
len_val_dataset = len(val_dataset)
stanext_data_info = StanExt.DATA_INFO
stanext_acc_joints = StanExt.ACC_JOINTS
return val_dataset, val_loader, len_val_dataset, test_name_list, stanext_data_info, stanext_acc_joints
def get_norm_dict(data_info=None, device="cuda"):
if data_info is None:
from stacked_hourglass.datasets.stanext24 import StanExt
data_info = StanExt.DATA_INFO
norm_dict = {
'pose_rot6d_mean': torch.from_numpy(data_info.pose_rot6d_mean).float().to(device),
'trans_mean': torch.from_numpy(data_info.trans_mean).float().to(device),
'trans_std': torch.from_numpy(data_info.trans_std).float().to(device),
'flength_mean': torch.from_numpy(data_info.flength_mean).float().to(device),
'flength_std': torch.from_numpy(data_info.flength_std).float().to(device)}
return norm_dict