Spaces:
Build error
Build error
""" | |
@Date: 2021/07/18 | |
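@description: Factory functions that build the configured dataset (mp3d, pano_s2d3d, pano_s2d3d_mix or zind) and the corresponding PyTorch train/validation data loaders.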
""" | |
import numpy as np | |
import torch.utils.data | |
from dataset.mp3d_dataset import MP3DDataset | |
from dataset.pano_s2d3d_dataset import PanoS2D3DDataset | |
from dataset.pano_s2d3d_mix_dataset import PanoS2D3DMixDataset | |
from dataset.zind_dataset import ZindDataset | |
def _largest_divisor_at_most(n, limit):
    """Return the largest integer ``d`` with ``1 <= d <= limit`` that divides ``n``.

    Used to pick a validation batch size that splits the validation set into
    equally sized batches. When ``n == 0`` every candidate divides it, so
    ``limit`` itself is returned.

    Raises:
        ValueError: if ``limit < 1``.
    """
    if limit < 1:
        raise ValueError(f"batch size must be a positive integer, got {limit}")
    # d == 1 always divides n, so the generator can never be exhausted.
    return next(d for d in range(limit, 0, -1) if n % d == 0)


def build_loader(config, logger):
    """Build the training and validation/test data loaders described by ``config``.

    A training loader is only built when ``config.MODE == 'train'``. The
    evaluation split is ``config.VAL_NAME``, except in ``'test'`` mode where
    the ``'test'`` split is used. When ``config.WORLD_SIZE > 1`` both loaders
    draw their indices from a ``torch.utils.data.DistributedSampler``.

    Args:
        config: configuration object. Reads ``MODE``, ``VAL_NAME``,
            ``WORLD_SIZE``, ``DEBUG``, ``DATA.BATCH_SIZE``,
            ``DATA.NUM_WORKERS`` and ``DATA.PIN_MEMORY``, plus whatever
            ``build_dataset`` reads.
        logger: logger used for progress messages; also passed to the datasets.

    Returns:
        ``(train_data_loader, val_data_loader)``. ``train_data_loader`` is
        ``None`` when not in training mode.
    """
    distributed = config.WORLD_SIZE > 1

    train_dataset = None
    if config.MODE == 'train':
        train_dataset = build_dataset(mode='train', config=config, logger=logger)
    val_mode = 'test' if config.MODE == 'test' else config.VAL_NAME
    val_dataset = build_dataset(mode=val_mode, config=config, logger=logger)

    train_sampler = None
    val_sampler = None
    if distributed:
        if train_dataset is not None:
            # NOTE(review): a DistributedSampler only reshuffles across epochs if
            # the training loop calls sampler.set_epoch(epoch) -- confirm caller.
            train_sampler = torch.utils.data.DistributedSampler(train_dataset, shuffle=True)
        val_sampler = torch.utils.data.DistributedSampler(val_dataset, shuffle=False)

    batch_size = config.DATA.BATCH_SIZE
    # Worker processes are disabled in debug mode.
    num_workers = 0 if config.DEBUG else config.DATA.NUM_WORKERS
    pin_memory = config.DATA.PIN_MEMORY

    train_data_loader = None
    if train_dataset is not None:
        logger.info(f'Train data loader batch size: {batch_size}')
        train_data_loader = torch.utils.data.DataLoader(
            train_dataset,
            sampler=train_sampler,
            batch_size=batch_size,
            # DataLoader raises if shuffle=True is combined with a sampler; the
            # DistributedSampler above is created with shuffle=True instead.
            shuffle=train_sampler is None,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=True,
        )

    # Shrink the validation batch size to the largest divisor of the dataset
    # length that does not exceed the configured batch size, so every
    # validation batch has the same size.
    # NOTE(review): under DDP each rank iterates len(val_sampler) samples, not
    # len(val_dataset); confirm whether the divisor should use that instead.
    val_batch_size = _largest_divisor_at_most(len(val_dataset), batch_size)
    logger.info(f'Val data loader batch size: {val_batch_size}')
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset,
        sampler=val_sampler,
        batch_size=val_batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    logger.info(f'Build data loader: num_workers:{num_workers} pin_memory:{pin_memory}')
    return train_data_loader, val_data_loader
def build_dataset(mode, config, logger):
    """Instantiate the dataset selected by ``config.DATA.DATASET`` for split ``mode``.

    Supported names are ``'mp3d'``, ``'pano_s2d3d'``, ``'pano_s2d3d_mix'`` and
    ``'zind'``. All of them receive the same core settings from
    ``config.DATA``. Augmentation (``config.DATA.AUG``) is only forwarded for
    the ``'train'`` split; every other split gets ``aug=None``.

    Args:
        mode: split name forwarded to the dataset constructor (e.g. ``'train'``).
        config: configuration object providing ``DATA`` (and ``EVAL`` for zind).
        logger: logger forwarded to the dataset constructor.

    Returns:
        The constructed dataset instance.

    Raises:
        NotImplementedError: if the dataset name is not one of the above.
    """
    dataset_name = config.DATA.DATASET
    # Reject unknown names before reading any other config field.
    if dataset_name not in ('mp3d', 'pano_s2d3d', 'pano_s2d3d_mix', 'zind'):
        raise NotImplementedError(f"Unknown dataset: {dataset_name}")

    data_cfg = config.DATA
    common_kwargs = dict(
        root_dir=data_cfg.DIR,
        mode=mode,
        shape=data_cfg.SHAPE,
        max_wall_num=data_cfg.WALL_NUM,
        aug=data_cfg.AUG if mode == 'train' else None,
        camera_height=data_cfg.CAMERA_HEIGHT,
        logger=logger,
        for_test_index=data_cfg.FOR_TEST_INDEX,
        keys=data_cfg.KEYS,
    )

    if dataset_name == 'mp3d':
        return MP3DDataset(**common_kwargs)
    if dataset_name == 'pano_s2d3d':
        return PanoS2D3DDataset(subset=data_cfg.SUBSET, **common_kwargs)
    if dataset_name == 'pano_s2d3d_mix':
        return PanoS2D3DMixDataset(subset=data_cfg.SUBSET, **common_kwargs)

    # Only 'zind' can reach this point. Vanishing-point alignment is enabled
    # when a 'manhattan' post-processing step is configured.
    post_processing = config.EVAL.POST_PROCESSING
    return ZindDataset(
        is_simple=True,
        is_ceiling_flat=False,
        vp_align=post_processing is not None and 'manhattan' in post_processing,
        **common_kwargs
    )