# NOTE: the dataset classes and the transforms wildcard import below look
# unused to linters, but get_data_loader() may eval() a dataset string in this
# module's global namespace, so these names must stay importable here.
from .arkit import ArkitScene  # noqa: F401
from .blendedmvs import BlendMVS  # noqa: F401
from .co3d import Co3d  # noqa: F401
from .habitat import habitat  # noqa: F401
from .scannet import Scannet  # noqa: F401
from .scannetpp import Scannetpp  # noqa: F401
from .seven_scenes import SevenScenes  # noqa: F401
from .nrgbd import NRGBD  # noqa: F401
from .dtu import DTU  # noqa: F401
from .demo import Demo  # noqa: F401
# presumably provides transform names referenced inside dataset strings — confirm
from dust3r.datasets.utils.transforms import *  # noqa: F401,F403


def get_data_loader(dataset, batch_size, num_workers=8, shuffle=True, drop_last=True, pin_mem=True):
    """Wrap ``dataset`` in a ``torch.utils.data.DataLoader`` with a suitable sampler.

    Args:
        dataset: A dataset object, or a string that is ``eval``-uated in this
            module's namespace to construct one.
        batch_size: Number of samples per batch.
        num_workers: Number of DataLoader worker processes.
        shuffle: Whether samples should be shuffled.
        drop_last: Whether to drop the final incomplete batch.
        pin_mem: Forwarded to the DataLoader's ``pin_memory`` argument.

    Returns:
        A ``torch.utils.data.DataLoader``.

    Sampler selection: the dataset's own ``make_sampler`` is preferred. If it
    is missing (AttributeError) or raises NotImplementedError, the loader uses
    ``DistributedSampler`` when ``torch.distributed`` is initialized, and
    otherwise ``RandomSampler`` or ``SequentialSampler`` depending on ``shuffle``.
    """
    import torch
    from croco.utils.misc import get_world_size, get_rank

    # SECURITY NOTE(review): eval() executes arbitrary code. Only pass dataset
    # strings that come from trusted sources (your own configs / CLI args).
    if isinstance(dataset, str):
        dataset = eval(dataset)

    world_size = get_world_size()
    rank = get_rank()

    try:
        sampler = dataset.make_sampler(batch_size, shuffle=shuffle, world_size=world_size,
                                       rank=rank, drop_last=drop_last)
    except (AttributeError, NotImplementedError):
        # make_sampler is unavailable for this dataset. Note that an
        # AttributeError raised *inside* make_sampler also lands here.
        if torch.distributed.is_initialized():
            sampler = torch.utils.data.DistributedSampler(
                dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=drop_last
            )
        elif shuffle:
            sampler = torch.utils.data.RandomSampler(dataset)
        else:
            sampler = torch.utils.data.SequentialSampler(dataset)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_mem,
        drop_last=drop_last,
    )

    return data_loader


def build_dataset(dataset, batch_size, num_workers, test=False):
    """Build a train or test DataLoader for ``dataset`` and log its size.

    Args:
        dataset: A dataset object or dataset specification string; see
            ``get_data_loader``.
        batch_size: Number of samples per batch.
        num_workers: Number of DataLoader worker processes.
        test: If true, build a test loader with no shuffling and no dropping
            of the last incomplete batch. Otherwise build a train loader that
            shuffles and drops it.

    Returns:
        The DataLoader returned by ``get_data_loader``.
    """
    # A conditional works for any truthy/falsy value. Indexing a list with
    # ``test`` only worked for bools and 0/1.
    split = 'Test' if test else 'Train'
    print(f'Building {split} Data loader for dataset: ', dataset)
    loader = get_data_loader(dataset,
                             batch_size=batch_size,
                             num_workers=num_workers,
                             pin_mem=True,
                             shuffle=not test,
                             drop_last=not test)

    # len() of a DataLoader is the number of batches, not the number of samples.
    print(f"{split} data loader length (batches): ", len(loader))
    return loader