from .arkit import ArkitScene
from .blendedmvs import BlendMVS
from .co3d import Co3d
from .habitat import habitat
from .scannet import Scannet
from .scannetpp import Scannetpp
from .seven_scenes import SevenScenes
from .nrgbd import NRGBD
from .dtu import DTU
from .demo import Demo
from dust3r.datasets.utils.transforms import *


def get_data_loader(dataset, batch_size, num_workers=8, shuffle=True, drop_last=True, pin_mem=True):
    import torch
    from croco.utils.misc import get_world_size, get_rank

    # pytorch dataset: a string such as "Co3d(split='train', ...)" is evaluated
    # into a dataset instance using the classes imported above
    if isinstance(dataset, str):
        dataset = eval(dataset)

    world_size = get_world_size()
    rank = get_rank()

    try:
        # prefer the dataset's own distributed-aware sampler if it provides one
        sampler = dataset.make_sampler(batch_size, shuffle=shuffle, world_size=world_size,
                                       rank=rank, drop_last=drop_last)
    except (AttributeError, NotImplementedError):
        # make_sampler not available for this dataset: fall back to standard samplers
        if torch.distributed.is_initialized():
            sampler = torch.utils.data.DistributedSampler(
                dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=drop_last
            )
        elif shuffle:
            sampler = torch.utils.data.RandomSampler(dataset)
        else:
            sampler = torch.utils.data.SequentialSampler(dataset)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_mem,
        drop_last=drop_last,
    )

    return data_loader


def build_dataset(dataset, batch_size, num_workers, test=False):
    split = ['Train', 'Test'][test]  # bool index: False -> 'Train', True -> 'Test'
    print(f'Building {split} Data loader for dataset: ', dataset)
    loader = get_data_loader(dataset,
                             batch_size=batch_size,
                             num_workers=num_workers,
                             pin_mem=True,
                             shuffle=not test,    # shuffle and drop the last batch only when training
                             drop_last=not test)

    print(f"{split} dataset length: ", len(loader))
    return loader
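

# A minimal usage sketch. The dataset string, its constructor arguments, and the
# batch settings below are illustrative assumptions, not the repo's actual
# training configuration; real dataset signatures come from the classes imported
# at the top of this module.
#
# if __name__ == '__main__':
#     train_loader = build_dataset("Co3d(split='train', resolution=224)",
#                                  batch_size=2, num_workers=4, test=False)
#     for views in train_loader:
#         pass  # a training step would consume `views` here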